import collections
import networkx as nx
from networkx import MultiDiGraph
from pgmpy.base._mixin_roles import _GraphRolesMixin
from pgmpy.base.DAG import DAG as pgmpy_DAG
[docs]
class ADMG(_GraphRolesMixin, MultiDiGraph):
"""
A class representing an Acyclic Directed Mixed Graph (ADMG).
An ADMG is a directed graph that allows for both directed and bidirected edges.
This class extends the `networkx.MultiDiGraph` and provides additional functionality
for operations involving directed and bidirected edges.
Parameters
----------
directed_ebunch : list of tuple, optional
List of directed edges to initialize the graph, where each tuple is (u, v).
bidirected_ebunch : list of tuple, optional
List of bidirected edges to initialize the graph, where each tuple is (u, v).
latents : set of str, optional
Set of latent variables in the graph. These are not directly represented as nodes
but are used to indicate the presence of bidirected edges.
roles : dict, optional (default: None)
A dictionary mapping roles to node names.
The keys are roles, and the values are role names (strings or iterables of str).
If provided, this will automatically assign roles to the nodes in the graph.
Passing a key-value pair via ``roles`` is equivalent to calling
``with_role(role, variables)`` for each key-value pair in the dictionary.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(
... directed_ebunch=[("X", "Y"), ("Z", "Y")], bidirected_ebunch=[("X", "Z")]
... )
>>> sorted(admg.nodes())
['X', 'Y', 'Z']
>>> sorted(admg.edges())
[('X', 'Y'), ('X', 'Z'), ('Z', 'X'), ('Z', 'Y')]
>>> admg.latents
set()
"""
def __init__(
self,
directed_ebunch=None,
bidirected_ebunch=None,
latents=None,
roles=None,
):
super().__init__()
# Using edge attributes to distinguish bidirected edges
if directed_ebunch:
self.add_directed_edges(directed_ebunch)
if bidirected_ebunch:
self.add_bidirected_edges(bidirected_ebunch)
self.latents = set(latents) if latents else set()
if roles is None:
roles = {}
elif not isinstance(roles, dict):
raise TypeError("Roles must be provided as a dictionary.")
# set the roles to the vertices as networkx attributes
for role, vars in roles.items():
self.with_role(role=role, variables=vars, inplace=True)
[docs]
def add_directed_edges(self, ebunch):
"""
Add directed edges (u -> v) to the ADMG.
Parameters
----------
ebunch : list of tuple
List of directed edges, where each tuple is (u, v).
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG()
>>> admg.add_directed_edges([("X", "Y"), ("Y", "Z")])
>>> sorted(admg.nodes())
['X', 'Y', 'Z']
>>> sorted(admg.edges())
[('X', 'Y'), ('Y', 'Z')]
"""
for u, v in ebunch:
if u is None or v is None:
raise ValueError("Can't add since one of nodes is None")
key = super().add_edge(u, v, type="directed")
if not nx.is_directed_acyclic_graph(self):
super().remove_edge(u, v, key=key)
raise ValueError("Adding this edge would create a cycle in the graph.")
[docs]
def add_bidirected_edges(self, ebunch):
"""
Add bidirected edges (u <-> v) to the ADMG.
Parameters
----------
ebunch : list of tuple
List of bidirected edges, where each tuple is (u, v).
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG()
>>> admg.add_bidirected_edges([("X", "Z")])
>>> sorted(admg.nodes())
['X', 'Z']
>>> sorted(admg.edges())
[('X', 'Z'), ('Z', 'X')]
"""
for u, v in ebunch:
if u is None or v is None:
raise ValueError("Can't add since one of the nodes is None")
if u == v:
raise ValueError("Cannot add a bidirected edge from a node to itself.")
# Add two directed edges with a 'type' attribute indicating bidirected
super().add_edge(u, v, type="bidirected")
super().add_edge(v, u, type="bidirected")
[docs]
def add_edge(self, u, v, **kwargs):
"""
Raise an error if trying to add a regular edge.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG()
>>> admg.add_edge("X", "Y")
Traceback (most recent call last):
...
NotImplementedError: Use add_directed_edge or add_bidirected_edge to add edges.
"""
raise NotImplementedError("Use add_directed_edge or add_bidirected_edge to add edges.")
[docs]
def get_directed_parents(self, nodes):
"""
Get directed parents of given nodes.
Parameters
----------
nodes : str or iterable of str
Node or list of nodes to query.
Returns
-------
set
Set of directed parents.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y"), ("Z", "Y")])
>>> sorted(admg.get_directed_parents("Y"))
['X', 'Z']
>>> admg.get_directed_parents("X")
set()
"""
nodes_set = {nodes} if isinstance(nodes, str) else set(nodes)
directed_parents = set()
for node in nodes_set:
if node not in self.nodes:
raise ValueError(f"Node {node} is not in the graph.")
for pred in self.predecessors(node):
data = self.get_edge_data(pred, node)
for key in data:
if data[key].get("type") == "directed":
directed_parents.add(pred)
return directed_parents
[docs]
def get_bidirected_parents(self, nodes):
"""
Get bidirected parents (nodes connected via bidirected edge) of the given nodes.
Parameters
----------
nodes : str or iterable of str
Node or list of nodes to query.
Returns
-------
set
Set of bidirected parents.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y")], bidirected_ebunch=[("X", "Z")])
>>> sorted(admg.get_bidirected_parents("X"))
['Z']
>>> admg.get_bidirected_parents("Y")
set()
"""
nodes_set = {nodes} if isinstance(nodes, str) else set(nodes)
bidirected_parents = set()
for node in nodes_set:
if node not in self.nodes:
raise ValueError(f"Node {node} is not in the graph.")
# Get neighbors and check for bidirected edges
for neighbor in super().neighbors(node):
if (
self.has_edge(node, neighbor) and self.get_edge_data(node, neighbor, 0).get("type") == "bidirected"
) or (
self.has_edge(neighbor, node) and self.get_edge_data(neighbor, node, 0).get("type") == "bidirected"
):
bidirected_parents.add(neighbor)
return bidirected_parents
[docs]
def get_children(self, nodes):
"""
Get children of given nodes (i.e., targets of outgoing directed edges).
Parameters
----------
nodes : str or iterable of str
Node or list of nodes.
Returns
-------
set
Set of children nodes.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y"), ("X", "Z")])
>>> sorted(admg.get_children("X"))
['Y', 'Z']
>>> admg.get_children("Y")
set()
"""
nodes_set = {nodes} if isinstance(nodes, str) else set(nodes)
children = set()
for node in nodes_set:
if node not in self.nodes:
raise ValueError(f"Node {node} is not in the graph.")
for successor in super().successors(node):
# Only consider truly directed edges
if self.get_edge_data(node, successor, 0)["type"] == "directed":
children.add(successor)
return children
[docs]
def get_spouses(self, nodes):
"""
Get spouses of given nodes (i.e., nodes connected via bidirected edges).
Parameters
----------
nodes : str or iterable of str
Node or list of nodes.
Returns
-------
set
Set of spouses.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y")], bidirected_ebunch=[("X", "Z")])
>>> sorted(admg.get_spouses("X"))
['Z']
>>> admg.get_spouses("Y")
set()
"""
nodes_set = {nodes} if isinstance(nodes, str) else set(nodes)
spouses = set()
for node in nodes_set:
if node not in self.nodes:
raise ValueError(f"Node {node} is not in the graph.")
for neighbor in super().neighbors(node):
# Check if the edge to/from the neighbor is bidirected
if (
self.has_edge(node, neighbor) and self.get_edge_data(node, neighbor, 0).get("type") == "bidirected"
) or (
self.has_edge(neighbor, node) and self.get_edge_data(neighbor, node, 0).get("type") == "bidirected"
):
spouses.add(neighbor)
return spouses
[docs]
def get_ancestors(self, nodes):
"""
Get ancestors of given nodes via directed paths.
Parameters
----------
nodes : str or iterable of str
Node or list of nodes.
Returns
-------
set
Set of ancestor nodes including the input nodes.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y"), ("Y", "Z")])
>>> sorted(admg.get_ancestors("Z"))
['X', 'Y', 'Z']
>>> sorted(admg.get_ancestors("X"))
['X']
"""
nodes_set = {nodes} if isinstance(nodes, str) else set(nodes)
ancestors = set()
for node in nodes_set:
if node in self.nodes:
# Use a temporary graph containing only directed edges for ancestry
temp_dag = nx.DiGraph()
for u, v, key, data in self.edges(keys=True, data=True):
if data.get("type") == "directed":
temp_dag.add_edge(u, v)
if node in temp_dag: # Ensure node exists in the temp_dag
ancestors.update(nx.ancestors(temp_dag, node).union({node}))
return ancestors
[docs]
def get_descendants(self, nodes):
"""
Get descendants of given nodes via directed paths.
Parameters
----------
nodes : str or iterable of str
Node or list of nodes.
Returns
-------
set
Set of descendant nodes including the input nodes.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y"), ("Y", "Z")])
>>> sorted(admg.get_descendants("X"))
['X', 'Y', 'Z']
>>> sorted(admg.get_descendants("Z"))
['Z']
"""
nodes_set = {nodes} if isinstance(nodes, str) else set(nodes)
descendants = set()
for node in nodes_set:
if node in self.nodes:
# Use a temporary graph containing only directed edges for descendants
temp_dag = nx.DiGraph()
for u, v, key, data in self.edges(keys=True, data=True):
if data.get("type") == "directed":
temp_dag.add_edge(u, v)
if node in temp_dag: # Ensure node exists in the temp_dag
descendants.update(nx.descendants(temp_dag, node).union({node}))
return descendants
[docs]
def get_district(self, nodes):
"""
Return district of a node: maximal set connected via bidirected edges.
Parameters
----------
nodes : str or iterable of str
Node or list of nodes.
Returns
-------
set
Nodes in the same bidirected-connected component.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y")], bidirected_ebunch=[("X", "Z")])
>>> sorted(admg.get_district("X"))
['X', 'Z']
>>> admg.get_district("Y")
{'Y'}
"""
nodes_set = {nodes} if isinstance(nodes, str) else set(nodes)
all_districts = set()
for start_node in nodes_set:
if start_node not in self.nodes:
continue
district_components = set()
queue = collections.deque([start_node])
visited = {start_node}
while queue:
currentNode = queue.popleft()
district_components.add(currentNode)
# Iterate through all neighbors and check for bidirected edges
for neighbor in super().neighbors(currentNode):
if (
self.has_edge(currentNode, neighbor)
and self.get_edge_data(currentNode, neighbor, 0).get("type") == "bidirected"
) or (
self.has_edge(neighbor, currentNode)
and self.get_edge_data(neighbor, currentNode, 0).get("type") == "bidirected"
):
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
for predecessor in super().predecessors(currentNode):
if (
self.has_edge(currentNode, predecessor)
and self.get_edge_data(currentNode, predecessor, 0).get("type") == "bidirected"
) or (
self.has_edge(predecessor, currentNode)
and self.get_edge_data(predecessor, currentNode, 0).get("type") == "bidirected"
):
if predecessor not in visited:
visited.add(predecessor)
queue.append(predecessor)
all_districts.update(district_components)
return all_districts
[docs]
def get_ancestral_graph(self, nodes):
"""
Return the ancestral graph induced by the input nodes.
Parameters
----------
nodes : str or iterable of str
Node or list of nodes to induce subgraph on.
Returns
-------
ADMG
Subgraph induced by ancestors of the given nodes.
Raises
------
ValueError
If any input node is not in the graph.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(
... directed_ebunch=[("X", "Y"), ("Y", "Z")], bidirected_ebunch=[("X", "Z")]
... )
>>> anc = admg.get_ancestral_graph(["Y", "Z"])
>>> sorted(anc.nodes())
['Y', 'Z']
>>> anc2 = admg.get_ancestral_graph(["X", "Y", "Z"])
>>> sorted(anc2.nodes())
['X', 'Y', 'Z']
"""
nodes_set = {nodes} if isinstance(nodes, str) else set(nodes)
if not nodes_set.issubset(self.nodes):
raise ValueError("Input nodes must be subset of graph's nodes.")
# Create a new ADMG instance for the ancestral graph
new_admg = ADMG()
new_admg.add_nodes_from(list(nodes_set)) # Add all nodes in nodes_set
# Add directed edges from the original graph that have both endpoints in nodes_set
for u, v, key, data in self.edges(keys=True, data=True):
if data.get("type") == "directed" and u in nodes_set and v in nodes_set:
new_admg.add_directed_edges([(u, v)]) # Use add_directed_edges to maintain cycle check
# Add bidirected edges from the original graph that have both endpoints in nodes_set
processed_bidirected_pairs = set()
for u, v, key, data in self.edges(keys=True, data=True):
if data.get("type") == "bidirected":
if u in nodes_set and v in nodes_set:
# Ensure we add each bidirected pair only once in the new graph
if (u, v) not in processed_bidirected_pairs and (
v,
u,
) not in processed_bidirected_pairs:
new_admg.add_bidirected_edges([(u, v)])
processed_bidirected_pairs.add((u, v))
processed_bidirected_pairs.add((v, u)) # Mark both directions as processed
return new_admg
[docs]
def get_markov_blanket(self, nodes):
"""
Compute the Markov blanket for the given node(s).
Includes:
- Parents
- Children
- Spouses (nodes sharing a child)
- Parents of nodes in the district
Parameters
----------
nodes : str or iterable of str
Node or list of nodes.
Returns
-------
set
Set of nodes in the Markov blanket.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(
... directed_ebunch=[("X", "Y"), ("Z", "Y")], bidirected_ebunch=[("X", "Z")]
... )
>>> sorted(admg.get_markov_blanket("Y"))
['X', 'Z']
"""
nodes_set = {nodes} if isinstance(nodes, set) else set(nodes)
if not nodes_set.issubset(self.nodes):
raise ValueError("Input nodes must be subset of graph's nodes.")
markov_blanket = set()
for node in nodes_set:
if node not in self.nodes:
raise ValueError(f"Node {node} is not in the graph.")
# Get parents
parents = self.get_directed_parents(node)
district_parents = self.get_bidirected_parents(node)
markov_blanket.update(parents)
markov_blanket.update(district_parents)
# Get children
children = self.get_children(node)
markov_blanket.update(children)
# Get spouses
spouses = self.get_spouses(node)
markov_blanket.update(spouses)
return markov_blanket
[docs]
def to_dag(self):
"""
Project ADMG into a DAG by introducing latent variables for bidirected edges.
Returns
-------
pgmpy.base.DAG.DAG
DAG with latent variables replacing bidirected edges.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y")], bidirected_ebunch=[("X", "Z")])
>>> dag = admg.to_dag()
>>> "L_X_Z" in dag.nodes()
True
>>> ("X", "Y") in dag.edges()
True
"""
dag_edges = []
# Add directed edges
for u, v, data in self.edges(data=True):
if data.get("type") == "directed":
dag_edges.append((u, v))
# add latent nodes and edges for bidirected edges
latent_nodes_map = {}
for u, v, data in self.edges(data=True):
if data.get("type") == "bidirected":
sorted_pair = tuple(sorted((u, v)))
if sorted_pair not in latent_nodes_map:
latent_var = f"L_{sorted_pair[0]}_{sorted_pair[1]}"
latent_nodes_map[sorted_pair] = latent_var
dag_edges.append((latent_var, sorted_pair[0]))
dag_edges.append((latent_var, sorted_pair[1]))
dag_nodes = set(self.nodes()) | set(latent_nodes_map.values())
# Create a new DAG instance
dag_instance = pgmpy_DAG()
dag_instance.add_nodes_from(dag_nodes)
dag_instance.add_edges_from(dag_edges)
return dag_instance
[docs]
def is_mseparated(
self,
nodes_u,
nodes_v,
conditional_set=None,
):
"""
Test m-separation between two sets of nodes given a conditioning set.
Parameters
----------
nodes_u : str or iterable of str
First set of nodes.
nodes_v : str or iterable of str
Second set of nodes.
conditional_set : set of str, optional
Conditioning set (default is empty set).
Returns
-------
bool
True if nodes_u and nodes_v are m-separated; False otherwise.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y"), ("Z", "Y")])
>>> admg.is_mseparated("X", "Z")
True
>>> admg.is_mseparated("X", "Z", conditional_set={"Y"})
False
"""
if conditional_set is None:
conditional_set = set()
# Convert nodes_u and nodes_v to sets
nodes_u_set = {nodes_u} if isinstance(nodes_u, str) else set(nodes_u)
nodes_v_set = {nodes_v} if isinstance(nodes_v, str) else set(nodes_v)
new_dag = self.to_dag()
for u in nodes_u_set:
for v in nodes_v_set:
# if they are dconnected, they are not mseparated
if new_dag.is_dconnected(u, v, observed=conditional_set):
return False
return True
[docs]
def is_mconnected(
self,
nodes_u,
nodes_v,
conditional_set=None,
):
"""
Test m-connectedness between two node sets given a conditioning set.
Parameters
----------
nodes_u : str or iterable of str
First set of nodes.
nodes_v : str or iterable of str
Second set of nodes.
conditional_set : set of str, optional
Conditioning set.
Returns
-------
bool
True if m-connected; False if m-separated.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y"), ("Z", "Y")])
>>> admg.is_mconnected("X", "Z", conditional_set={"Y"})
True
>>> admg.is_mconnected("X", "Z")
False
"""
return not self.is_mseparated(nodes_u, nodes_v, conditional_set)
[docs]
def mconnected_nodes(self, nodes_u, nodes_v=None, conditional_set=None):
"""
Find all nodes m-connected to nodes in `nodes_u` given `conditional_set`.
Parameters
----------
nodes_u : str or iterable of str
Set of source nodes.
nodes_v : str or iterable of str, optional
If provided, filters the result to this set.
conditional_set : set of str, optional
Conditioning set (default is empty set).
Returns
-------
set
Nodes m-connected to `nodes_u` (or their intersection with `nodes_v` if provided).
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg = ADMG(directed_ebunch=[("X", "Y"), ("Y", "Z")])
>>> sorted(admg.mconnected_nodes("X", nodes_v=["Y", "Z"]))
['Y', 'Z']
>>> sorted(admg.mconnected_nodes("X", nodes_v=["Z"]))
['Z']
"""
if conditional_set is None:
conditional_set = set()
dag = self.to_dag()
if isinstance(nodes_u, str):
nodes_u = [nodes_u]
m_connected_set = set()
for node in nodes_u:
active_trail = dag.active_trail_nodes(node, observed=conditional_set)
# active_trail_nodes returns a dict {node: set_of_active_nodes}
for active_nodes in active_trail.values():
m_connected_set.update({n for n in active_nodes if not str(n).startswith("L_")})
if nodes_v is not None:
nodes_v_set = {nodes_v} if isinstance(nodes_v, str) else set(nodes_v)
return m_connected_set & nodes_v_set
return m_connected_set
def __eq__(self, other):
"""
Check if two ADMGs are equal.
Two ADMGs are considered equal if they have the same nodes, edges,
latent variables, and variable roles.
Parameters
----------
other : ADMG
The other ADMG to compare with.
Returns
-------
bool
True if the ADMGs are equal, False otherwise.
Examples
--------
>>> from pgmpy.base.ADMG import ADMG
>>> admg1 = ADMG(directed_ebunch=[("X", "Y")], bidirected_ebunch=[("X", "Z")])
>>> admg2 = ADMG(directed_ebunch=[("X", "Y")], bidirected_ebunch=[("X", "Z")])
>>> admg1 == admg2
True
>>> admg3 = ADMG(directed_ebunch=[("X", "Y")])
>>> admg1 == admg3
False
"""
if not isinstance(other, ADMG):
return False
if (
set(self.nodes()) != set(other.nodes())
or self.latents != other.latents
or self.get_role_dict() != other.get_role_dict()
or set(self.edges()) != set(other.edges())
):
return False
# Check edges type more details ('directed' or 'bidirected').
for u, v in self.edges():
if self.get_edge_data(u, v, 0)["type"] != other.get_edge_data(u, v, 0)["type"]:
return False
return True