# Source code for pgmpy.models.JunctionTree

#!/usr/bin/env python3

import networkx as nx

from pgmpy.models import ClusterGraph


class JunctionTree(ClusterGraph):
    """
    Class for representing Junction Tree.

    Junction tree is an undirected graph where each node represents a clique
    (list, tuple or set of nodes) and edges represent the sepset between two
    cliques. Each sepset in G separates the variables strictly on one side of
    the edge from the other.

    Parameters
    ----------
    ebunch: iterable of edges, optional
        Data to initialize the graph. If ebunch=None (default) an empty graph
        is created. The data is an edge list; every edge is added through
        :meth:`add_edge`.

    Examples
    --------
    Create an empty JunctionTree with no nodes and no edges

    >>> from pgmpy.models import JunctionTree
    >>> G = JunctionTree()

    G can be grown by adding clique nodes.

    **Nodes:**

    Add a tuple (or list or set) of nodes as single clique node.

    >>> G.add_node(('a', 'b', 'c'))
    >>> G.add_nodes_from([('a', 'b'), ('a', 'b', 'c')])

    **Edges:**

    G can also be grown by adding edges.

    >>> G.add_edge(('a', 'b', 'c'), ('a', 'b'))

    or a list of edges

    >>> G.add_edges_from([(('a', 'b', 'c'), ('a', 'b')),
    ...                   (('a', 'b', 'c'), ('a', 'c'))])
    """

    def __init__(self, ebunch=None):
        super().__init__()
        # An empty/None ebunch leaves the tree empty.
        if ebunch:
            self.add_edges_from(ebunch)

    def add_edge(self, u, v, **kwargs):
        """
        Add an edge between two clique nodes.

        Parameters
        ----------
        u, v: nodes
            Nodes can be any list or set or tuple of nodes forming a clique.

        **kwargs
            Forwarded unchanged to the parent class' ``add_edge``.

        Raises
        ------
        ValueError
            If both `u` and `v` are already in the graph and a path between
            them exists, i.e. the new edge would close a cycle.

        Examples
        --------
        >>> from pgmpy.models import JunctionTree
        >>> G = JunctionTree()
        >>> G.add_nodes_from([('a', 'b', 'c'), ('a', 'b'), ('a', 'c')])
        >>> G.add_edges_from([(('a', 'b', 'c'), ('a', 'b')),
        ...                   (('a', 'b', 'c'), ('a', 'c'))])
        """
        # A tree must stay acyclic: two already-connected nodes may not be
        # joined by another edge.
        if u in self.nodes() and v in self.nodes() and nx.has_path(self, u, v):
            raise ValueError(
                f"Addition of edge between {str(u)} and {str(v)} forms a cycle breaking the properties of Junction Tree"
            )

        super().add_edge(u, v, **kwargs)

    @property
    def states(self):
        """
        Returns a dictionary mapping each variable to its possible states.

        The mapping is built by merging the ``state_names`` dictionaries of
        all factors in ``self.factors``; when a key appears in several
        factors, the entry of the factor that comes last wins.

        Returns
        -------
        state_dict: dict
            Dictionary of nodes to possible states
        """
        state_names_list = [phi.state_names for phi in self.factors]
        state_dict = {
            node: states for d in state_names_list for node, states in d.items()
        }
        return state_dict

    def check_model(self):
        """
        Check the model for various errors.

        This method first verifies that the junction tree is connected and
        then delegates every remaining check to the parent class'
        ``check_model``.

        * The running intersection property is not checked explicitly here.
          ``add_edge`` only guarantees that no cycle is introduced.

        Returns
        -------
        check: boolean
            The result of the parent class' ``check_model`` (presumably True
            if all the checks are passed).

        Raises
        ------
        ValueError
            If the graph is not connected.
        """
        if not nx.is_connected(self):
            raise ValueError("The Junction Tree defined is not fully connected.")

        return super().check_model()

    def copy(self):
        """
        Returns a copy of JunctionTree.

        Edges and nodes are re-added to a fresh JunctionTree and every factor
        is copied through its own ``copy`` method, so the returned model does
        not share factor objects with the original.

        Returns
        -------
        JunctionTree : copy of JunctionTree

        Examples
        --------
        >>> import numpy as np
        >>> from pgmpy.factors.discrete import DiscreteFactor
        >>> from pgmpy.models import JunctionTree
        >>> G = JunctionTree()
        >>> G.add_edges_from([(('a', 'b', 'c'), ('a', 'b')), (('a', 'b', 'c'), ('a', 'c'))])
        >>> phi1 = DiscreteFactor(['a', 'b'], [1, 2], np.random.rand(2))
        >>> phi2 = DiscreteFactor(['a', 'c'], [1, 2], np.random.rand(2))
        >>> G.add_factors(phi1,phi2)
        >>> modelCopy = G.copy()
        >>> modelCopy.edges()
        EdgeView([(('a', 'b'), ('a', 'b', 'c')), (('a', 'c'), ('a', 'b', 'c'))])
        >>> G.factors
        [<DiscreteFactor representing phi(a:1, b:2) at 0xb720ee4c>,
         <DiscreteFactor representing phi(a:1, c:2) at 0xb4e1e06c>]
        >>> modelCopy.factors
        [<DiscreteFactor representing phi(a:1, b:2) at 0xb4bd11ec>,
         <DiscreteFactor representing phi(a:1, c:2) at 0xb4bd138c>]
        """
        tree_copy = JunctionTree(self.edges())
        # Isolated nodes are not carried by the edge list, so add all nodes.
        tree_copy.add_nodes_from(self.nodes())
        if self.factors:
            factors_copy = [factor.copy() for factor in self.factors]
            tree_copy.add_factors(*factors_copy)
        return tree_copy