Junction Tree Exact Inference
This example demonstrates how to construct a Junction Tree, define clique potentials, and perform exact inference using belief propagation. We model a simple chain of influence for university admission:
Difficulty (D) of a course affects the Grade (G).
Grade (G) affects the Recommendation Letter (L) quality.
Recommendation Letter (L) affects the Admission (A) result.
[1]:
from pgmpy.models import JunctionTree
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.inference import BeliefPropagation
In a junction tree, each node is a clique of variables from the original model. For our chain scenario, the cliques are pairs of adjacent variables:
Clique 1: (Difficulty, Grade)
Clique 2: (Grade, Letter)
Clique 3: (Letter, Admission)
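Adjacent cliques must overlap in the variables they share; these shared variables form the separators that messages pass over during belief propagation. A minimal sketch (plain Python, independent of pgmpy) that computes the separator between each adjacent pair of cliques:

```python
# Cliques as tuples of variable names, mirroring the ones defined below
cliques = [
    ("Difficulty", "Grade"),
    ("Grade", "Letter"),
    ("Letter", "Admission"),
]

# The separator between two adjacent cliques is the set of shared variables
for left, right in zip(cliques, cliques[1:]):
    separator = set(left) & set(right)
    print(f"{left} -- {right}: separator = {separator}")
```

Here the separators are {Grade} and {Letter}, which is why the edges added below connect exactly these clique pairs.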
[2]:
# Initializing the Junction Tree
junction_tree = JunctionTree()
print(junction_tree)
JunctionTree with 0 nodes and 0 edges
[3]:
# Defining Cliques
clique_d_g = ("Difficulty", "Grade")
clique_g_l = ("Grade", "Letter")
clique_l_a = ("Letter", "Admission")
# Adding Nodes
junction_tree.add_nodes_from([clique_d_g, clique_g_l, clique_l_a])
[4]:
print("Nodes:", junction_tree.nodes())
Nodes: [('Difficulty', 'Grade'), ('Grade', 'Letter'), ('Letter', 'Admission')]
[5]:
# We connect cliques that share variables
junction_tree.add_edges_from([
    (clique_d_g, clique_g_l),
    (clique_g_l, clique_l_a)
])
[6]:
print("Edges:", junction_tree.edges())
Edges: [(('Difficulty', 'Grade'), ('Grade', 'Letter')), (('Grade', 'Letter'), ('Letter', 'Admission'))]
For this example, we define binary discrete factors (0 = Low/False, 1 = High/True) to represent the relationships between these variables.
Phi 1 (Difficulty, Grade): Hard courses make high grades less likely.
Phi 2 (Grade, Letter): High grades lead to stronger letters.
Phi 3 (Letter, Admission): Strong letters increase admission chances.
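In pgmpy, the flat `values` list of a `DiscreteFactor` is filled in row-major order over the variables as listed: the first variable indexes rows and the second indexes columns. A quick numpy sketch of how the values used for `phi_d_g` below map onto a 2x2 table:

```python
import numpy as np

# The flat list is reshaped row-major: the first listed variable varies slowest
values = [0.3, 0.7,
          0.8, 0.2]
table = np.array(values).reshape(2, 2)  # axis 0 = Difficulty, axis 1 = Grade

print(table[0])     # Difficulty=0 (Easy): weights for Grade=0, Grade=1
print(table[1, 1])  # Difficulty=1 (Hard), Grade=1 (High) -> 0.2
```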
[7]:
# Row Difficulty=0 (Easy): [Low Grade, High Grade]
# Row Difficulty=1 (Hard): [Low Grade, High Grade]
phi_d_g = DiscreteFactor(
    variables=["Difficulty", "Grade"],
    cardinality=[2, 2],
    values=[0.3, 0.7,
            0.8, 0.2]
)

# Row Grade=0 (Low):  [Weak Letter, Strong Letter]
# Row Grade=1 (High): [Weak Letter, Strong Letter]
phi_g_l = DiscreteFactor(
    variables=["Grade", "Letter"],
    cardinality=[2, 2],
    values=[0.9, 0.1,
            0.1, 0.9]
)

# Row Letter=0 (Weak):   [Rejected, Accepted]
# Row Letter=1 (Strong): [Rejected, Accepted]
phi_l_a = DiscreteFactor(
    variables=["Letter", "Admission"],
    cardinality=[2, 2],
    values=[0.95, 0.05,
            0.2, 0.8]
)

junction_tree.add_factors(phi_d_g, phi_g_l, phi_l_a)
junction_tree.add_factors(phi_d_g, phi_g_l, phi_l_a)
[8]:
print(f"Model Validity: {junction_tree.check_model()}")
Model Validity: True
Since the Junction Tree is a tree structure, we can use Belief Propagation (BP) to perform exact inference. We will query the probability of getting Accepted.
[9]:
belief_propagation = BeliefPropagation(junction_tree)
[10]:
# Query: What is the probability of Admission?
marginal_admission = belief_propagation.query(variables=["Admission"])
print("\nMarginal Probability of Admission:")
print(marginal_admission)
Marginal Probability of Admission:
+--------------+------------------+
| Admission | phi(Admission) |
+==============+==================+
| Admission(0) | 0.6050 |
+--------------+------------------+
| Admission(1) | 0.3950 |
+--------------+------------------+
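As a sanity check, the same marginal can be recovered by brute-force enumeration: multiply the three potentials over every assignment of (Difficulty, Grade, Letter, Admission), sum out everything except Admission, and normalize. A minimal plain-Python sketch, using the same factor values defined above:

```python
from itertools import product

# Factor values keyed by variable assignments, copied from the factors above
phi_dg = {(0, 0): 0.3, (0, 1): 0.7, (1, 0): 0.8, (1, 1): 0.2}
phi_gl = {(0, 0): 0.9, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.9}
phi_la = {(0, 0): 0.95, (0, 1): 0.05, (1, 0): 0.2, (1, 1): 0.8}

# Unnormalized marginal over Admission via the sum-product of all factors
marginal = [0.0, 0.0]
for d, g, l, a in product([0, 1], repeat=4):
    marginal[a] += phi_dg[(d, g)] * phi_gl[(g, l)] * phi_la[(l, a)]

total = sum(marginal)
print([round(m / total, 4) for m in marginal])  # [0.605, 0.395]
```

This matches the belief propagation result, as it should: on a tree, message passing computes exactly this sum-product, just without enumerating every joint assignment.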
[11]:
# Query: Joint Probability of Grade and Letter
joint_grade_letter = belief_propagation.query(variables=["Grade", "Letter"])
print("\nJoint Probability (Grade, Letter):")
print(joint_grade_letter)
Joint Probability (Grade, Letter):
+----------+-----------+---------------------+
| Grade | Letter | phi(Grade,Letter) |
+==========+===========+=====================+
| Grade(0) | Letter(0) | 0.4950 |
+----------+-----------+---------------------+
| Grade(0) | Letter(1) | 0.0550 |
+----------+-----------+---------------------+
| Grade(1) | Letter(0) | 0.0450 |
+----------+-----------+---------------------+
| Grade(1) | Letter(1) | 0.4050 |
+----------+-----------+---------------------+
We can also query the model given some evidence.
[12]:
# Query: Probability of Admission given the course was Hard
print("\nProbability of Admission | Difficulty=Hard:")
query_evidence = belief_propagation.query(
    variables=["Admission"],
    evidence={"Difficulty": 1}
)
print(query_evidence)
Probability of Admission | Difficulty=Hard:
+--------------+------------------+
| Admission | phi(Admission) |
+==============+==================+
| Admission(0) | 0.7550 |
+--------------+------------------+
| Admission(1) | 0.2450 |
+--------------+------------------+
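Conditioning on evidence amounts to fixing the observed variable before summing out the rest. The conditional above can be reproduced by the same brute-force enumeration, restricted to Difficulty = 1 (again using the factor values defined earlier):

```python
from itertools import product

# Factor values keyed by variable assignments, copied from the factors above
phi_dg = {(0, 0): 0.3, (0, 1): 0.7, (1, 0): 0.8, (1, 1): 0.2}
phi_gl = {(0, 0): 0.9, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.9}
phi_la = {(0, 0): 0.95, (0, 1): 0.05, (1, 0): 0.2, (1, 1): 0.8}

d = 1  # evidence: the course was Hard
marginal = [0.0, 0.0]
for g, l, a in product([0, 1], repeat=3):
    marginal[a] += phi_dg[(d, g)] * phi_gl[(g, l)] * phi_la[(l, a)]

total = sum(marginal)
print([round(m / total, 4) for m in marginal])  # [0.755, 0.245]
```

As expected, a hard course makes a high grade, and therefore a strong letter and admission, less likely than in the unconditioned marginal.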