"""Source code for ``pgmpy.base.AncestralBase``."""

from collections import deque
from collections.abc import Hashable, Iterable

import networkx as nx
import numpy as np

from pgmpy.base._mixin_roles import _GraphRolesMixin
from pgmpy.utils.parser import parse_dagitty


class AncestralBase(nx.Graph, _GraphRolesMixin):
    def __init__(
        self,
        ebunch: Iterable[tuple[Hashable, Hashable, str, str]] | None = None,
        latents: Iterable[Hashable] | None = None,
        exposures: Iterable[Hashable] | None = None,
        outcomes: Iterable[Hashable] | None = None,
        roles: dict | None = None,
    ):
        """
        Ancestral graph base class.

        Internally, each edge is stored with an attribute dictionary called
        ``marks``. The ``marks`` dict maps the two endpoint nodes to their
        respective marks, for example:

        - Directed: ("A", "B", "-", ">") is stored as ("A", "B", {"marks": {"A": "-", "B": ">"}})
        - Bidirected: ("A", "B", ">", ">") is stored as ("A", "B", {"marks": {"A": ">", "B": ">"}})
        - Undirected: ("A", "B", "-", "-") is stored as ("A", "B", {"marks": {"A": "-", "B": "-"}})
        - Circle endpoint: ("A", "B", "o", ">") is stored as ("A", "B", {"marks": {"A": "o", "B": ">"}})

        Parameters
        ----------
        ebunch : Iterable[tuple], optional
            An iterable of edges of the form (u, v, u_mark, v_mark) used to
            initialize the graph. Each mark must be one of {">", "-", "o"}.
            Default is None, which initializes an empty graph.

        latents : Iterable, optional
            Latent (unobserved) variables in the graph. A new set is built from
            the argument, so the caller's container is never shared.
            Default is None (no latent variables).

        exposures : Iterable, optional
            Exposure variables in the graph, i.e. the treatment or intervention
            being studied in a causal analysis. Copied into a new set.
            Default is None (no exposures).

        outcomes : Iterable, optional
            Outcome variables in the graph, i.e. the response or dependent
            variables being studied in a causal analysis. Copied into a new set.
            Default is None (no outcomes).

        roles : dict, optional (default: None)
            A dictionary mapping role names to the nodes that play that role.
            Each value is a node name or an iterable of node names. Passing a
            key-value pair via ``roles`` is equivalent to calling
            ``with_role(role, variables, inplace=True)`` for that pair.

        Raises
        ------
        TypeError
            If ``roles`` is given but is not a dict.
        ValueError
            Raised by ``add_edge`` if an edge in ``ebunch`` is a self-loop or
            uses a mark outside {">", "-", "o"}.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [("A", "B", "-", ">"), ("B", "C", ">", "-")]
        >>> graph = AncestralBase(ebunch=edges)
        >>> list(graph.edges(data=True))
        [('A', 'B', {'marks': {'A': '-', 'B': '>'}}), ('B', 'C', {'marks': {'B': '>', 'C': '-'}})]
        >>> graph.add_edge("C", "D", "o", "o")
        >>> list(graph.edges(data=True))
        [('A', 'B', {'marks': {'A': '-', 'B': '>'}}), ('B', 'C', {'marks': {'B': '>', 'C': '-'}}), ('C', 'D', {'marks': {'C': 'o', 'D': 'o'}})]

        Roles can be assigned to nodes in the graph at construction or using
        methods.

        At construction:

        >>> g = AncestralBase(
        ...     ebunch=[("L", "A", "-", ">"), ("B", "C", "-", ">")],
        ...     latents={"L"},
        ...     exposures={"A"},
        ...     outcomes={"B"},
        ... )

        Roles can also be assigned after creation using the ``with_role`` method.

        >>> g = g.with_role("adjustment", {"L", "C"})

        Vertices of a specific role can be retrieved using the ``get_role`` method.

        >>> g.get_role("exposures")
        ["A"]
        >>> g.get_role("adjustment")
        ["L", "C"]
        """
        super().__init__()
        # Must exist before any edge is added: add_edge validates against it.
        self.valid_marks = {">", "-", "o"}
        if ebunch:
            self.add_edges_from(ebunch)

        # None defaults (instead of shared mutable set() defaults) plus a fresh copy.
        self.latents = set(latents) if latents is not None else set()
        self.exposures = set(exposures) if exposures is not None else set()
        self.outcomes = set(outcomes) if outcomes is not None else set()

        if roles is None:
            roles = {}
        elif not isinstance(roles, dict):
            raise TypeError("Roles must be provided as dictionary")
        for role, variables in roles.items():
            self.with_role(role=role, variables=variables, inplace=True)

    @property
    def adjacency_matrix(self):
        """
        Return the adjacency matrix with edge marks and the node-to-index mapping.

        Returns
        -------
        M : np.ndarray
            A square object matrix of shape (n_nodes, n_nodes) where M[i, j] is
            the mark at node j for edge (i, j), and 0 where there is no edge.

        node_index : dict
            Mapping from node label to row/col index, in ``self.nodes`` order.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [("A", "B", "-", ">"), ("B", "C", ">", "-")]
        >>> graph = AncestralBase(ebunch=edges)
        >>> M, node_index = graph.adjacency_matrix
        >>> print(M)
        [[0 '>' 0]
         ['-' 0 '-']
         [0 '>' 0]]
        >>> print(node_index)
        {'A': 0, 'B': 1, 'C': 2}
        """
        nodes = list(self.nodes)
        n = len(nodes)
        node_index = {node: i for i, node in enumerate(nodes)}
        M = np.full((n, n), 0, dtype=object)
        for u, v, data in self.edges(data=True):
            marks = data["marks"]
            u_idx, v_idx = node_index[u], node_index[v]
            # Row = "from" endpoint, column = endpoint whose mark is stored.
            M[u_idx, v_idx] = marks[v]
            M[v_idx, u_idx] = marks[u]
        return M, node_index

    @adjacency_matrix.setter
    def adjacency_matrix(self, value):
        """
        Replace the graph with the one described by an adjacency matrix of marks.

        Uses the same convention as the getter, so
        ``g.adjacency_matrix = g.adjacency_matrix[0]`` preserves every edge
        orientation (node labels become ``X_0 .. X_{n-1}``).

        Parameters
        ----------
        value : array-like
            A square matrix where value[i, j] is the mark at node j for edge
            (i, j). Marks must be one of {">", "-", "o"}, or 0 for no edge.
            Pairs where either entry is 0 are treated as "no edge".

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``value`` is not a square 2-D matrix, or (from ``add_edge``) if
            a non-zero entry is not a valid mark.

        Examples
        --------
        >>> import numpy as np
        >>> from pgmpy.base import AncestralBase
        >>> M = np.array([[0, ">", 0], ["-", 0, ">"], [0, "-", 0]], dtype=object)
        >>> graph = AncestralBase()
        >>> graph.adjacency_matrix = M
        >>> print(graph.nodes)
        ['X_0', 'X_1', 'X_2']
        >>> print(graph.edges(data=True))
        [('X_0', 'X_1', {'marks': {'X_0': '-', 'X_1': '>'}}), ('X_1', 'X_2', {'marks': {'X_1': '-', 'X_2': '>'}})]
        """
        value = np.asarray(value)
        if value.ndim != 2 or value.shape[0] != value.shape[1]:
            raise ValueError("Adjacency matrix must be square (n x n).")

        n = value.shape[0]
        variables = [f"X_{i}" for i in range(n)]
        self.clear()
        # Register every variable so nodes without edges are not silently dropped.
        self.add_nodes_from(variables)
        # Each unordered pair is visited once; value[i, j] is the mark at j and
        # value[j, i] the mark at i (same convention as the getter).
        for i in range(n):
            for j in range(i + 1, n):
                mark_at_i = value[j, i]
                mark_at_j = value[i, j]
                if mark_at_i != 0 and mark_at_j != 0:
                    self.add_edge(variables[i], variables[j], mark_at_i, mark_at_j)
[docs] def add_edge(self, u, v, u_mark, v_mark): """ Add an edge with specified marks. Parameters ---------- u : Hashable One endpoint of the edge. v : Hashable The other endpoint of the edge. u_mark : str Mark at node u for edge (u, v). Must be one of {">", "-", "o"}. v_mark : str Mark at node v for edge (u, v). Must be one of {">", "-", "o"}. Returns ------- None Adds the edge to the graph in-place Examples -------- >>> from pgmpy.base import AncestralBase >>> g = AncestralBase() # Directed edge A → B >>> g.add_edge("A", "B", "-", ">") >>> g["A"]["B"]["marks"] {'A': '-', 'B': '>'} # Bidirected edge A ↔ D >>> g.add_edge("A", "D", ">", ">") >>> g["A"]["D"]["marks"] {'A': '>', 'D': '>'} # Undirected edge C — E >>> g.add_edge("C", "E", "-", "-") >>> g["C"]["E"]["marks"] {'C': '-', 'E': '-'} """ if u == v: raise ValueError("Nodes cannot be the same for an edge.") if u_mark not in self.valid_marks or v_mark not in self.valid_marks: raise ValueError(f"Marks must be one of {self.valid_marks}.") super().add_edge(u, v, marks={u: u_mark, v: v_mark})
[docs] def add_edges_from(self, ebunch): """ Add multiple edges from an iterable of (u, v, marks) tuples. Parameters ---------- ebunch : Iterable[tuple] Each tuple should be of the form (u, v, u_mark, v_mark). Returns ------- None Adds the edges to the graph in-place. Examples -------- >>> from pgmpy.base import AncestralBase >>> g = AncestralBase() >>> edges = [("A", "B", "-", ">"), ("B", "C", ">", "-"), ("C", "D", "o", "o")] >>> g.add_edges_from(edges) >>> list(g.edges(data=True)) [('A', 'B', {'marks': {'A': '-', 'B': '>'}}), ('B', 'C', {'marks': {'B': '>', 'C': '-'}}), ('C', 'D', {'marks': {'C': 'o', 'D': 'o'}})] """ for u, v, u_mark, v_mark in ebunch: self.add_edge(u, v, u_mark, v_mark)
[docs] def get_neighbors(self, node, u_type=None, v_type=None): """ Get neighbors of a node with optional edge mark constraints. Parameters ---------- node : Hashable The node whose neighbors are to be found. u_type : Optional[str] Required mark at the given node for the edge. v_type : Optional[str] Required mark at the neighbor node for the edge. Returns ------- neighbors : set Set of neighboring nodes satisfying the mark constraints. Examples -------- >>> from pgmpy.base import AncestralBase >>> edges = [("A", "B", "-", ">"), ("B", "C", ">", "-"), ("C", "D", "o", "o")] >>> graph = AncestralBase(ebunch=edges) >>> print(graph.get_neighbors("B")) {'A', 'C'} >>> print(graph.get_neighbors("B", u_type=">")) {'C', 'A'} >>> print(graph.get_neighbors("B", v_type="-")) {'A', 'C'} >>> print(graph.get_neighbors("B", u_type=">", v_type="-")) {'C', 'A'} """ if node not in self: return set() neighbors = set() for neighbor in nx.all_neighbors(self, node): node_mark, neighbor_mark = ( self.edges[node, neighbor]["marks"][node], self.edges[node, neighbor]["marks"][neighbor], ) if (u_type is None or node_mark == u_type) and (v_type is None or neighbor_mark == v_type): neighbors.add(neighbor) return neighbors
[docs] def get_parents(self, node): """ Get nodes that point to this node with '>' Parameters ---------- node : Hashable The node whose parents are to be found. Returns ------- parents : set Set of parent nodes. Examples -------- >>> from pgmpy.base import AncestralBase >>> edges = [("A", "B", "-", ">"), ("C", "B", "-", ">"), ("B", "D", "-", ">")] >>> graph = AncestralBase(ebunch=edges) >>> print(graph.get_parents("B")) {'A', 'C'} >>> print(graph.get_parents("D")) {'B'} >>> print(graph.get_parents("A")) set() """ return self.get_neighbors(node, u_type=">", v_type="-")
[docs] def get_children(self, node): """ Get nodes that this node points to with '>' Parameters ---------- node : Hashable The node whose children are to be found. Returns ------- children : set Set of child nodes. Examples -------- >>> from pgmpy.base import AncestralBase >>> edges = [("A", "B", "-", ">"), ("A", "C", "-", ">"), ("B", "D", "-", ">")] >>> graph = AncestralBase(ebunch=edges) >>> print(graph.get_children("A")) {'B', 'C'} >>> print(graph.get_children("B")) {'D'} >>> print(graph.get_children("D")) set() """ return self.get_neighbors(node, u_type="-", v_type=">")
[docs] def get_spouses(self, node): """ Get nodes connected by bidirectional '>' edges (spouses). Parameters ---------- node : Hashable The node whose spouses are to be found. Returns ------- spouses : set Set of spouse nodes. Examples -------- >>> from pgmpy.base import AncestralBase >>> edges = [("A", "B", ">", ">"), ("A", "C", "-", ">"), ("C", "D", ">", ">")] >>> graph = AncestralBase(ebunch=edges) >>> print(graph.get_spouses("A")) {'B'} >>> print(graph.get_spouses("C")) {'D'} >>> print(graph.get_spouses("B")) {'A'} """ return self.get_neighbors(node, u_type=">", v_type=">")
[docs] def get_ancestors(self, node): """ Get all ancestor nodes of the given node. Parameters ---------- node : Hashable The node whose ancestors are to be found. Returns ------- ancestors : set Set of ancestor nodes including the starting node. Examples -------- >>> from pgmpy.base import AncestralBase >>> edges = [ ... ("A", "B", "-", ">"), ... ("B", "C", "-", ">"), ... ("C", "D", "-", ">"), ... ("E", "C", "-", ">"), ... ] >>> graph = AncestralBase(ebunch=edges) >>> print(graph.get_ancestors("D")) {'A', 'B', 'C', 'D', 'E'} >>> print(graph.get_ancestors("C")) {'A', 'B', 'C', 'E'} >>> print(graph.get_ancestors("A")) {'A'} """ ancestors = set() visited = set() queue = deque(node) while queue: current = queue.popleft() if current not in visited: visited.add(current) ancestors.add(current) queue.extend(self.get_parents(current)) return ancestors
[docs] def get_descendants(self, node): """ Get all descendant nodes (children, grandchildren, etc.) Parameters ---------- node : Hashable The starting node. Returns ------- descendants : set Set of descendant nodes including the starting node. Examples -------- >>> from pgmpy.base import AncestralBase >>> edges = [ ... ("A", "B", "-", ">"), ... ("B", "C", "-", ">"), ... ("C", "D", "-", ">"), ... ("B", "E", "-", ">"), ... ] >>> graph = AncestralBase(ebunch=edges) >>> print(graph.get_descendants("A")) {'A', 'B', 'C', 'D', 'E'} >>> print(graph.get_descendants("B")) {'B', 'C', 'D', 'E'} >>> print(graph.get_descendants("D")) {'D'} """ descendants = set() visited = set() queue = deque(node) while queue: current = queue.popleft() if current not in visited: visited.add(current) descendants.add(current) queue.extend(self.get_children(current)) return descendants
[docs] def get_reachable_nodes(self, node, u_type=None, v_type=None): """ Get all nodes reachable from the given node following edges with specified marks. Parameters ---------- node : Hashable The starting node. u_type : Optional[str] Required mark at the current node for traversal. v_type : Optional[str] Required mark at the neighbor node for traversal. Returns ------- reachable : set Set of reachable nodes including the starting node. Examples -------- >>> from pgmpy.base import AncestralBase >>> edges = [ ... ("A", "B", "-", ">"), ... ("B", "C", "-", ">"), ... ("A", "D", "o", "o"), ... ("D", "E", "o", "o"), ... ] >>> graph = AncestralBase(ebunch=edges) >>> print(graph.get_reachable_nodes("A", v_type=">")) {'A', 'B', 'C'} >>> print(graph.get_reachable_nodes("A", u_type="o", v_type="o")) {'A', 'D', 'E'} """ reachable = set() visited = set() queue = deque(node) while queue: current = queue.popleft() if current not in visited: visited.add(current) reachable.add(current) queue.extend(self.get_neighbors(current, u_type=u_type, v_type=v_type)) return reachable
[docs] def to_dagitty(self) -> str: """ Convert the MAG to a Dagitty string representation. Returns ------- str A string in Dagitty format representing the MAG. Examples -------- >>> from pgmpy.base import MAG >>> mag = MAG() >>> mag.add_edge("X", "Y", "-", ">") >>> mag.add_edge("Z", "Y", "-", ">") >>> print(mag.to_dagitty()) mag { X -> Y Z -> Y } >>> mag2 = MAG() >>> mag2.add_edge("A", "B", ">", ">") >>> mag2.add_edge("C", "D", "-", "-") >>> print(mag2.to_dagitty()) mag { A <-> B C -- D } >>> # MAG with latent variables and roles >>> mag3 = MAG() >>> mag3.add_edge("L", "X", "-", ">") >>> mag3.add_edge("X", "Y", "-", ">") >>> mag3.latents = {"L"} >>> mag3 = mag3.with_role("exposures", "X") >>> mag3 = mag3.with_role("outcomes", "Y") >>> print(mag3.to_dagitty()) mag { L -> X X -> Y L [latents] Y [outcome] X [exposure] } References ---------- dagitty syntax: https://cran.r-project.org/web/packages/dagitty/dagitty.pdf """ target_type = self.__class__.__name__ lines = [f"{target_type.lower()} {{"] edge_map = { ("-", ">"): "->", (">", "-"): "<-", (">", ">"): "<->", ("o", ">"): "@->", (">", "o"): "<-@", ("o", "o"): "@-@", ("o", "-"): "@--", ("-", "o"): "--@", ("-", "-"): "--", } for u, v in self.edges: marks = self.edges[u, v]["marks"] u_mark, v_mark = marks[u], marks[v] if (u_mark, v_mark) in edge_map: symbol = edge_map[(u_mark, v_mark)] if symbol in ["<-", "<-@", "--@"]: lines.append(f"{v} {symbol[::-1]} {u}") else: lines.append(f"{u} {symbol} {v}") dagitty_role_map = {"exposures": "exposure", "outcomes": "outcome"} for role in self.get_roles(): for var in self.get_role(role): lines.append(f"{var} [{dagitty_role_map.get(role, role)}]") lines.append("}") return "\n".join(lines)
[docs] @classmethod def from_dagitty(cls, string: str = None, filename: str = None): """ Populate the MAG from a Dagitty string representation. Parameters ---------- string : str A string in dagitty format representing the MAG. filename : str, optional Path to file containing Dagitty format string. Returns ------- MAG A new MAG instance created from the Dagitty representation. Examples -------- >>> from pgmpy.base import MAG >>> dag_str = '''dag { ... L -> A ... B -> C ... L [latents] ... B [outcome] ... A [exposure] ... }''' >>> mag = MAG.from_dagitty(dag_str) """ if filename: with open(filename) as f: dagitty_lines = [line.strip() for line in f.readlines()] elif string: dagitty_lines = [line.strip() for line in string.split("\n")] else: raise ValueError("Either `filename` or `string` need to be specified") ebunch, roles, _, nodes = parse_dagitty(dagitty_lines) return cls(ebunch=ebunch, roles=roles)
def __eq__(self, other): """ Checks if two MAGs are equal. Two MAGs are equal if they have the same nodes, edges(including marks), latent variables, and variable roles Parameters ---------- other: MAG object The other MAG to compare with Returns ------- bool True if the MAGs are equal, False otherwise Examples -------- >>> from pgmpy.base import MAG >>> mag1 = MAG( ... ebunch=[("X", "Y", "-", ">"), ("Y", "Z", "-", ">")], ... latents={"L"}, ... roles={"exposures": "X"}, ... ) >>> mag2 = MAG( ... ebunch=[("X", "Y", "-", ">"), ("Y", "Z", "-", ">")], ... latents={"L"}, ... roles={"exposures": "X"}, ... ) >>> mag1 == mag2 True >>> mag3 = MAG( ... ebunch=[("X", "Y", "-", ">")], latents={"L"}, roles={"exposures": "X"} ... ) >>> mag1 == mag3 False """ if not isinstance(other, AncestralBase): return False self_edges = {(u, v, frozenset(data["marks"].items())) for u, v, data in self.edges(data=True)} other_edges = {(u, v, frozenset(data["marks"].items())) for u, v, data in other.edges(data=True)} return ( set(self.nodes()) == set(other.nodes()) and self_edges == other_edges and self.latents == other.latents and self.get_role_dict() == other.get_role_dict() )
[docs] def copy(self): """ Return a copy of the graph, preserving nodes, edges, marks, latents, and roles. Returns ------- AncestralBase A new instance of the same class as self with all properties copied. """ ebunch = [(u, v, data["marks"][u], data["marks"][v]) for u, v, data in self.edges(data=True)] ancestral_base = self.__class__( ebunch=ebunch, latents=self.latents.copy(), exposures=self.exposures.copy(), outcomes=self.outcomes.copy(), ) for role, vars in self.get_role_dict().items(): ancestral_base.with_role(role=role, variables=vars, inplace=True) return ancestral_base