Source code for pgmpy.base.PDAG

import itertools
from typing import Hashable, Iterable

import networkx as nx

from pgmpy.base._mixin_roles import _GraphRolesMixin
from pgmpy.global_vars import logger


class PDAG(_GraphRolesMixin, nx.DiGraph):
    """
    Class for representing PDAGs (also known as CPDAG).

    PDAGs are the equivalence classes of DAGs and contain both directed and
    undirected edges.

    Note: In this class, undirected edges are represented using two edges in
    both direction i.e. an undirected edge between X - Y is represented using
    X -> Y and X <- Y. In addition, the attributes ``directed_edges`` and
    ``undirected_edges`` keep track of which edges are of which kind.

    Parameters
    ----------
    directed_ebunch: list, array-like of 2-tuples
        List of directed edges in the PDAG.

    undirected_ebunch: list, array-like of 2-tuples
        List of undirected edges in the PDAG.

    latents: list, array-like
        List of nodes which are latent variables.

    exposures : set, default=set()
        Set of exposure variables in the graph. These are the variables that
        represent the treatment or intervention being studied in a causal
        analysis. Default is an empty set.

    outcomes : set, default=set()
        Set of outcome variables in the graph. These are the variables that
        represent the response or dependent variables being studied in a
        causal analysis. Default is an empty set.

    roles : dict, optional (default: None)
        A dictionary mapping roles to node names. The keys are roles, and the
        values are node names (strings or iterables of str). If provided,
        this will automatically assign roles to the nodes in the graph.
        Passing a key-value pair via ``roles`` is equivalent to calling
        ``with_role(role, variables)`` for each key-value pair in the
        dictionary.

    Raises
    ------
    TypeError
        If ``roles`` is neither ``None`` nor a ``dict``.

    Examples
    --------
    >>> from pgmpy.base import PDAG
    >>> pdag = PDAG(
    ...     directed_ebunch=[("A", "C"), ("D", "C")],
    ...     undirected_ebunch=[("B", "A"), ("B", "D")],
    ...     latents=["E"],
    ...     roles={"exposure": ["A"], "outcome": ["C"]},
    ... )
    >>> pdag.directed_edges
    {('A', 'C'), ('D', 'C')}
    >>> pdag.undirected_edges
    {('B', 'A'), ('B', 'D')}
    >>> pdag.latents
    {'E'}
    """

    def __init__(
        self,
        directed_ebunch: list[tuple[Hashable, Hashable]] = [],
        undirected_ebunch: list[tuple[Hashable, Hashable]] = [],
        latents: Iterable[Hashable] = [],
        exposures: set[Hashable] = set(),
        outcomes: set[Hashable] = set(),
        roles=None,
    ):
        # The mutable defaults are kept for signature compatibility. They are
        # only read here and copied into fresh sets, never mutated.
        self.directed_edges = set(directed_ebunch)
        self.undirected_edges = set(undirected_ebunch)

        # Every undirected edge X - Y is stored in the underlying DiGraph as
        # the pair of arcs X -> Y and Y -> X.
        reversed_undirected = {(y, x) for (x, y) in self.undirected_edges}
        super().__init__(
            self.directed_edges | self.undirected_edges | reversed_undirected
        )

        self.latents = set(latents)
        self.exposures = set(exposures)
        self.outcomes = set(outcomes)

        if roles is None:
            roles = {}
        elif not isinstance(roles, dict):
            raise TypeError("Roles must be provided as a dictionary.")

        # NOTE(review): roles are registered through the mixin's `with_role`;
        # whether they are also reflected in `self.exposures` /
        # `self.outcomes` depends on `_GraphRolesMixin` -- confirm.
        for role, variables in roles.items():
            self.with_role(role=role, variables=variables, inplace=True)

    def all_neighbors(self, node) -> set:
        """
        Returns a set of all neighbors of a node in the PDAG. This includes
        both directed and undirected edges.

        Parameters
        ----------
        node: any hashable python object
            The node for which to get the neighboring nodes.

        Returns
        -------
        set: A set of neighboring nodes.

        Examples
        --------
        >>> from pgmpy.base import PDAG
        >>> pdag = PDAG(
        ...     directed_ebunch=[("A", "C"), ("D", "C")],
        ...     undirected_ebunch=[("B", "A"), ("B", "D")],
        ... )
        >>> pdag.all_neighbors("A")
        {'B', 'C'}
        """
        return set(self.successors(node)) | set(self.predecessors(node))

    def directed_children(self, node) -> set:
        """
        Returns a set of children of node such that there is a directed edge
        from `node` to child.
        """
        return {x for x in self.successors(node) if (node, x) in self.directed_edges}

    def directed_parents(self, node) -> set:
        """
        Returns a set of parents of node such that there is a directed edge
        from the parent to `node`.
        """
        return {x for x in self.predecessors(node) if (x, node) in self.directed_edges}

    def has_directed_edge(self, u, v) -> bool:
        """
        Returns True if there is a directed edge u -> v in the PDAG.
        """
        return (u, v) in self.directed_edges

    def has_undirected_edge(self, u, v) -> bool:
        """
        Returns True if there is an undirected edge u - v in the PDAG.

        The edge is found regardless of the order in which its endpoints were
        given when it was registered.
        """
        return (u, v) in self.undirected_edges or (v, u) in self.undirected_edges

    def undirected_neighbors(self, node) -> set:
        """
        Returns a set of neighboring nodes such that all of them have an
        undirected edge with `node`.

        Parameters
        ----------
        node: any hashable python object
            The node for which to get the undirected neighboring nodes.

        Returns
        -------
        set: A set of neighboring nodes.

        Examples
        --------
        >>> from pgmpy.base import PDAG
        >>> pdag = PDAG(
        ...     directed_ebunch=[("A", "C"), ("D", "C")],
        ...     undirected_ebunch=[("B", "A"), ("B", "D")],
        ... )
        >>> pdag.undirected_neighbors("A")
        {'B'}
        """
        # An undirected edge is the only case where arcs exist in both
        # directions in the underlying DiGraph.
        return {var for var in self.successors(node) if self.has_edge(var, node)}

    def is_adjacent(self, u, v) -> bool:
        """
        Returns True if there is an edge between u and v. This can be either
        of u - v, u -> v, or u <- v.
        """
        return self.has_edge(u, v) or self.has_edge(v, u)

    def copy(self):
        """
        Returns a copy of the object instance.

        The copy carries over the nodes, directed and undirected edges,
        latents, exposures, outcomes and the role assignments.

        Returns
        -------
        Copy of PDAG: pgmpy.base.PDAG
            Returns a copy of self.
        """
        pdag = PDAG(
            directed_ebunch=list(self.directed_edges),
            undirected_ebunch=list(self.undirected_edges),
            latents=self.latents,
            exposures=self.exposures,
            outcomes=self.outcomes,
        )
        # Isolated nodes are not part of any edge and must be added explicitly.
        pdag.add_nodes_from(self.nodes())
        for role, variables in self.get_role_dict().items():
            pdag.with_role(role=role, variables=variables, inplace=True)
        return pdag

    def _directed_graph(self):
        """
        Returns a subgraph containing only directed edges.

        All nodes of the PDAG are kept, including those that have no
        directed edge.
        """
        dag = nx.DiGraph(self.directed_edges)
        dag.add_nodes_from(self.nodes())
        return dag

    def orient_undirected_edge(self, u, v, inplace=False):
        """
        Orients an undirected edge u - v as u -> v.

        Parameters
        ----------
        u, v: Any hashable python objects
            The node names.

        inplace: boolean (default=False)
            If True, the PDAG object is modified inplace, otherwise a new
            modified copy is returned.

        Returns
        -------
        None or pgmpy.base.PDAG: The modified PDAG object. If inplace=True,
            returns None and the object itself is modified. If inplace=False,
            returns a PDAG object.

        Raises
        ------
        ValueError
            If there is no undirected edge between `u` and `v`.
        """
        pdag = self if inplace else self.copy()

        forward, backward = (u, v), (v, u)
        if (
            forward not in pdag.undirected_edges
            and backward not in pdag.undirected_edges
        ):
            raise ValueError(f"Undirected Edge {u} - {v} not present in the PDAG.")

        # The same undirected edge may have been registered in either or both
        # orientations. Drop every registration so that no stale entry keeps
        # reporting the edge as undirected after it has been oriented.
        pdag.undirected_edges.discard(forward)
        pdag.undirected_edges.discard(backward)

        # Remove the inverse edge from the graph
        pdag.remove_edge(v, u)

        # Add the edge to directed_edges.
        pdag.directed_edges.add(forward)

        if not inplace:
            return pdag

    def _check_new_unshielded_collider(self, u, v) -> bool:
        """
        Tests if orienting an undirected edge u - v as u -> v creates new
        unshielded V-structures in the PDAG.

        Checks whether v has any directed parents other than u that are not
        adjacent to u.

        Returns
        -------
        True, if the orientation u -> v would lead to creation of a new
        V-structure. False, if no new V-structures are formed.
        """
        return any(
            node != u and not self.is_adjacent(u, node)
            for node in self.directed_parents(v)
        )

    def apply_meeks_rules(self, apply_r4=False, inplace=False, debug=False):
        """
        Applies the Meek's rules to orient the undirected edges of a PDAG to
        return a CPDAG.

        The rules are applied repeatedly until a full pass orients no edge.

        Parameters
        ----------
        apply_r4: boolean (default=False)
            If True, applies Rules 1 - 4 of Meek's rules. If False, applies
            only Rules 1 - 3.

        inplace: boolean (default=False)
            If True, the PDAG object is modified inplace, otherwise a new
            modified copy is returned.

        debug: boolean (default=False)
            If True, logs the rules being applied to the PDAG.

        Returns
        -------
        None or pgmpy.base.PDAG: The modified PDAG object. If inplace=True,
            returns None and the object itself is modified. If inplace=False,
            returns a PDAG object.

        Examples
        --------
        >>> from pgmpy.base import PDAG
        >>> pdag = PDAG(
        ...     directed_ebunch=[("A", "B")], undirected_ebunch=[("B", "C"), ("C", "B")]
        ... )
        >>> cpdag = pdag.apply_meeks_rules()
        >>> cpdag.directed_edges
        {('A', 'B'), ('B', 'C')}
        """
        pdag = self if inplace else self.copy()

        changed = True
        while changed:
            changed = False

            # Rule 1: If X -> Y - Z and
            #            (X not adj Z) and
            #            (adding Y -> Z doesn't create cycle) and
            #            (adding Y -> Z doesn't create an unshielded collider) => Y → Z
            for y in pdag.nodes():
                # Select x's such that there are directed edges x -> y.
                for x in pdag.directed_parents(y):
                    # The neighbor set is recomputed for every x because an
                    # earlier iteration may already have oriented an edge.
                    for z in pdag.undirected_neighbors(y):
                        if (
                            (not pdag.is_adjacent(x, z))
                            and (not pdag._check_new_unshielded_collider(y, z))
                            and (not nx.has_path(pdag._directed_graph(), z, y))
                        ):
                            pdag.orient_undirected_edge(y, z, inplace=True)
                            changed = True
                            if debug:
                                logger.info(
                                    f"Applying Rule 1: {x} -> {y} - {z} => {x} -> {y} -> {z}"
                                )

            # Rule 2: If X -> Z -> Y  and X - Y => X → Y
            for z in pdag.nodes():
                xs = pdag.directed_parents(z)
                ys = pdag.directed_children(z)

                for x in xs:
                    for y in ys:
                        if pdag.has_undirected_edge(x, y):
                            pdag.orient_undirected_edge(x, y, inplace=True)
                            changed = True
                            if debug:
                                logger.info(
                                    f"Applying Rule 2: {x} -> {z} -> {y} and {x} - {y} => {x} -> {y}"
                                )

            # Rule 3: If X - {Y, Z, W} and {Z, Y} -> W => X -> W
            # NOTE(review): Meek's Rule 3 additionally requires Y and Z to be
            # non-adjacent; that condition is not checked here -- confirm
            # whether this is intentional.
            for x in pdag.nodes():
                undirected_nbs = pdag.undirected_neighbors(x)

                if len(undirected_nbs) < 3:
                    continue

                for y, z, w in itertools.permutations(undirected_nbs, 3):
                    if pdag.has_directed_edge(y, w) and pdag.has_directed_edge(z, w):
                        pdag.orient_undirected_edge(x, w, inplace=True)
                        changed = True
                        if debug:
                            logger.info(
                                f"Applying Rule 3: {x} - {y}, {z}, {w} "
                                f"{y}, {z} -> {w} => {x} -> {w}"
                            )
                        # `undirected_nbs` is stale after the orientation.
                        break

            # Rule 4: If d -> c -> b & a - {b, c, d} and b not adj d => a -> b
            if apply_r4:
                for c in pdag.nodes():
                    for b in pdag.directed_children(c):
                        for d in pdag.directed_parents(c):
                            if b == d or pdag.is_adjacent(b, d):
                                continue  # b adjacent d => rule not applicable

                            # find nodes a that are undirected neighbor to b, d,
                            # and directed or undirected neighbor to c
                            cand = set(pdag.undirected_neighbors(b)).intersection(
                                pdag.all_neighbors(c),
                                pdag.undirected_neighbors(d),
                            )
                            for a in cand:
                                pdag.orient_undirected_edge(a, b, inplace=True)
                                changed = True
                                if debug:
                                    logger.info(
                                        f"Applying Rule 4: {d} -> {c} -> {b} and "
                                        f"{a} - {b}, {c}, {d} => {a} -> {b}"
                                    )
                                # At most one candidate is oriented per (b, d).
                                break

        if not inplace:
            return pdag

    def to_dag(self):
        """
        Returns one possible DAG which is represented using the PDAG.

        Returns
        -------
        pgmpy.base.DAG: Returns an instance of DAG.

        Examples
        --------
        >>> from pgmpy.base import PDAG
        >>> pdag = PDAG(
        ...     directed_ebunch=[("A", "B"), ("C", "B")],
        ...     undirected_ebunch=[("C", "D"), ("D", "A")],
        ... )
        >>> dag = pdag.to_dag()
        >>> print(dag.edges())
        OutEdgeView([('A', 'B'), ('C', 'B'), ('D', 'C'), ('A', 'D')])

        References
        ----------
        [1] Dor, Dorit, and Michael Tarsi.
          "A simple algorithm to construct a consistent extension of a partially oriented graph."
            Technical Report R-185, Cognitive Systems Laboratory, UCLA (1992): 45.
        """
        # Add required edges if it doesn't form a new v-structure or an opposite edge
        # is already present in the network.
        # Imported locally, as in the original code, instead of at module level.
        from pgmpy.base import DAG

        dag = DAG()
        # Add all the nodes and the directed edges
        dag.add_nodes_from(self.nodes())
        dag.add_edges_from(self.directed_edges)
        dag.latents = self.latents

        # Work on a copy: nodes are removed from it one by one.
        pdag = self.copy()
        while pdag.number_of_nodes() > 0:
            # find node with (1) no directed outgoing edges and
            #                (2) the set of undirected neighbors is either empty or
            #                    undirected neighbors + parents of X are adjacent
            found = False
            for X in sorted(pdag.nodes()):
                undirected_neighbors = pdag.undirected_neighbors(X)
                neighbors_are_adjacent = all(
                    pdag.has_edge(Y, Z) or pdag.has_edge(Z, Y)
                    for Z in pdag.all_neighbors(X)
                    for Y in undirected_neighbors
                    if not Y == Z
                )

                if not pdag.directed_children(X) and (
                    not undirected_neighbors or neighbors_are_adjacent
                ):
                    found = True
                    # orient every undirected edge of X towards X (Y -> X)
                    for Y in pdag.undirected_neighbors(X):
                        dag.add_edge(Y, X)
                    pdag.remove_node(X)
                    break

            if not found:
                logger.warning(
                    "PDAG has no faithful extension (= no oriented DAG with the "
                    + "same v-structures as PDAG). Remaining undirected PDAG edges "
                    + "oriented arbitrarily."
                )
                for X, Y in pdag.edges():
                    if not dag.has_edge(Y, X):
                        # Best effort: an edge whose insertion is rejected
                        # with a ValueError is skipped.
                        try:
                            dag.add_edge(X, Y)
                        except ValueError:
                            pass
                break
        return dag

    def to_graphviz(self) -> object:
        """
        Returns a pygraphviz object for the PDAG.

        pygraphviz is useful for visualizing the network structure.

        Examples
        --------
        >>> from pgmpy.utils import get_example_model
        >>> model = get_example_model("alarm")
        >>> model.to_graphviz()
        <AGraph <Swig Object of type 'Agraph_t *' at 0x7fdea4cde040>>
        """
        return nx.nx_agraph.to_agraph(self)

    def __eq__(self, other):
        """
        Checks if two PDAGs are equal.

        Two PDAGs are considered equal if they have the same nodes, edges,
        latent variables, and variable roles. Undirected edges are compared
        irrespective of the order of their endpoints.

        Parameters
        ----------
        other: PDAG object
            The other PDAG to compare with.

        Returns
        -------
        bool
            True if the PDAGs are equal, False otherwise.
        """
        if not isinstance(other, PDAG):
            return False

        # X - Y may be stored as (X, Y) or (Y, X); compare unordered pairs so
        # that the storage order does not affect equality.
        own_undirected = {frozenset(edge) for edge in self.undirected_edges}
        other_undirected = {frozenset(edge) for edge in other.undirected_edges}

        return (
            set(self.nodes()) == set(other.nodes())
            and self.directed_edges == other.directed_edges
            and own_undirected == other_undirected
            and self.latents == other.latents
            and self.get_role_dict() == other.get_role_dict()
        )