6. Causal Inference Examples

6.1. Simpson’s paradox

[1]:
from pgmpy.models import BayesianNetwork
from pgmpy.inference import VariableElimination
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import CausalInference

6.1.1. Model Definition

[2]:
# Directed edges of the Simpson's-paradox graph: S -> T, T -> C, S -> C.
simp_model = BayesianNetwork(
    [
        ('S', 'T'),
        ('T', 'C'),
        ('S', 'C'),
    ]
)

# Explicit node coordinates handed to the daft rendering of the graph.
simp_layout = {'T': (0, 0), 'C': (2, 0), 'S': (1, 1)}
simp_model.to_daft(node_pos=simp_layout).render()
[2]:
<matplotlib.axes._axes.Axes at 0x7f1761cbb700>
../_images/examples_Causal_Inference_4_1.png
[3]:
# P(S): the two states 'm' and 'f' each get probability 0.5.
cpd_s = TabularCPD(
    variable='S',
    variable_card=2,
    values=[[0.5], [0.5]],
    state_names={'S': ['m', 'f']},
)

# P(T | S): rows are the states of T, one column per state of S.
cpd_t = TabularCPD(
    variable='T',
    variable_card=2,
    values=[
        [0.25, 0.75],
        [0.75, 0.25],
    ],
    evidence=['S'],
    evidence_card=[2],
    state_names={'S': ['m', 'f'], 'T': [0, 1]},
)

# P(C | S, T): rows are the states of C, one column per (S, T) combination.
# Evidence order is ['S', 'T'], matching evidence_card=[2, 2].
cpd_c = TabularCPD(
    variable='C',
    variable_card=2,
    values=[
        [0.3, 0.4, 0.7, 0.8],
        [0.7, 0.6, 0.3, 0.2],
    ],
    evidence=['S', 'T'],
    evidence_card=[2, 2],
    state_names={'S': ['m', 'f'], 'T': [0, 1], 'C': [0, 1]},
)

# Attach all three CPDs to the Simpson's-paradox network.
simp_model.add_cpds(cpd_s, cpd_t, cpd_c)

6.1.2. Inference conditioning on T

[4]:
# Non-adjusted inference: T is supplied as ordinary evidence (conditioning),
# first for T = 1 and then for T = 0.
infer_non_adjust = VariableElimination(simp_model)
for observed_t in (1, 0):
    print(infer_non_adjust.query(variables=['C'], evidence={'T': observed_t}))
+------+----------+
| C    |   phi(C) |
+======+==========+
| C(0) |   0.5000 |
+------+----------+
| C(1) |   0.5000 |
+------+----------+
+------+----------+
| C    |   phi(C) |
+======+==========+
| C(0) |   0.6000 |
+------+----------+
| C(1) |   0.4000 |
+------+----------+

6.1.3. Inference with do-operation on T

[5]:
# Interventional queries: T is set through the do-operation instead of being
# passed as evidence, first for T = 1 and then for T = 0.
infer_adjusted = CausalInference(simp_model)
for forced_t in (1, 0):
    print(infer_adjusted.query(variables=['C'], do={'T': forced_t}))
+------+----------+
| C    |   phi(C) |
+======+==========+
| C(0) |   0.6000 |
+------+----------+
| C(1) |   0.4000 |
+------+----------+
+------+----------+
| C    |   phi(C) |
+======+==========+
| C(0) |   0.5000 |
+------+----------+
| C(1) |   0.5000 |
+------+----------+

6.2. Specifying adjustment sets

[6]:
# Directed edges of the second example graph: X -> Y, Z -> X, Z -> W, W -> Y.
model = BayesianNetwork(
    [
        ('X', 'Y'),
        ('Z', 'X'),
        ('Z', 'W'),
        ('W', 'Y'),
    ]
)

# P(Z): Z has no evidence variables.
cpd_z = TabularCPD(
    variable='Z',
    variable_card=2,
    values=[[0.2], [0.8]],
)

# P(X | Z)
cpd_x = TabularCPD(
    variable='X',
    variable_card=2,
    values=[
        [0.1, 0.3],
        [0.9, 0.7],
    ],
    evidence=['Z'],
    evidence_card=[2],
)

# P(W | Z)
cpd_w = TabularCPD(
    variable='W',
    variable_card=2,
    values=[
        [0.2, 0.9],
        [0.8, 0.1],
    ],
    evidence=['Z'],
    evidence_card=[2],
)

# P(Y | X, W): evidence order is ['X', 'W'], matching evidence_card=[2, 2].
cpd_y = TabularCPD(
    variable='Y',
    variable_card=2,
    values=[
        [0.3, 0.4, 0.7, 0.8],
        [0.7, 0.6, 0.3, 0.2],
    ],
    evidence=['X', 'W'],
    evidence_card=[2, 2],
)

# Attach the CPDs to the network.
model.add_cpds(cpd_z, cpd_x, cpd_w, cpd_y)

# Explicit node coordinates handed to the daft rendering of the graph.
model_layout = {'X': (0, 0), 'Y': (2, 0), 'Z': (0, 2), 'W': (2, 2)}
model.to_daft(node_pos=model_layout).render()
[6]:
<matplotlib.axes._axes.Axes at 0x7f1760a8ce20>
../_images/examples_Causal_Inference_11_1.png
[7]:
# do(X = 1) on the four-node model, each time with an explicitly chosen
# adjustment set: {W}, then {Z}, then {W, Z}.
infer = CausalInference(model)

do_X_W = infer.query(
    variables=['Y'],
    do={'X': 1},
    adjustment_set=['W'],
)
print(do_X_W)

do_X_Z = infer.query(
    variables=['Y'],
    do={'X': 1},
    adjustment_set=['Z'],
)
print(do_X_Z)

do_X_WZ = infer.query(
    variables=['Y'],
    do={'X': 1},
    adjustment_set=['W', 'Z'],
)
print(do_X_WZ)

# The same kind of query on the Simpson's-paradox model: do(T = 1), adjusting for S.
infer_simp = CausalInference(simp_model)
do_simpson = infer_simp.query(
    variables=['C'],
    do={'T': 1},
    adjustment_set=['S'],
)
print(do_simpson)
+------+----------+
| Y    |   phi(Y) |
+======+==========+
| Y(0) |   0.7240 |
+------+----------+
| Y(1) |   0.2760 |
+------+----------+
+------+----------+
| Y    |   phi(Y) |
+======+==========+
| Y(0) |   0.7240 |
+------+----------+
| Y(1) |   0.2760 |
+------+----------+
+------+----------+
| Y    |   phi(Y) |
+======+==========+
| Y(0) |   0.7240 |
+------+----------+
| Y(1) |   0.2760 |
+------+----------+
+------+----------+
| C    |   phi(C) |
+======+==========+
| C(0) |   0.6000 |
+------+----------+
| C(1) |   0.4000 |
+------+----------+
[8]:
# Queries that pass only an adjustment set (no do-operation and no evidence),
# using the adjustment sets {W}, then {Z}, then {W, Z}.
infer = CausalInference(model)

adj_W = infer.query(variables=['Y'], adjustment_set=['W'])
print(adj_W)

adj_Z = infer.query(variables=['Y'], adjustment_set=['Z'])
print(adj_Z)

adj_WZ = infer.query(variables=['Y'], adjustment_set=['W', 'Z'])
print(adj_WZ)

# The same kind of query on the Simpson's-paradox model, adjusting for S.
infer_simp = CausalInference(simp_model)
adj_simpson = infer_simp.query(variables=['C'], adjustment_set=['S'])
print(adj_simpson)
+------+----------+
| Y    |   phi(Y) |
+======+==========+
| Y(0) |   0.6200 |
+------+----------+
| Y(1) |   0.3800 |
+------+----------+
+------+----------+
| Y    |   phi(Y) |
+======+==========+
| Y(0) |   0.6200 |
+------+----------+
| Y(1) |   0.3800 |
+------+----------+
+------+----------+
| Y    |   phi(Y) |
+======+==========+
| Y(0) |   0.6200 |
+------+----------+
| Y(1) |   0.3800 |
+------+----------+
+------+----------+
| C    |   phi(C) |
+======+==========+
| C(0) |   0.5500 |
+------+----------+
| C(1) |   0.4500 |
+------+----------+