{
"cells": [
{
"cell_type": "markdown",
"id": "345065a4",
"metadata": {},
"source": [
"# Expert Knowledge Integration with Causal Discovery"
]
},
{
"cell_type": "markdown",
"id": "e1d422b1",
"metadata": {},
"source": [
"Causal discovery (a.k.a. structure learning) seeks to uncover the causal relationships among random variables from observational data. Methods such as the PC algorithm and Hill Climb Search can be used for causal discovery; however, they often struggle to recover the correct causal relationships because the data may not satisfy the method's assumptions, limited sample sizes lead to errors, and compute time becomes high when there are many variables or the model is dense.\n",
"\n",
"Incorporating expert knowledge can greatly improve the accuracy of these methods and reduce their computational cost, because it constrains the search space of the algorithms. pgmpy currently supports the following ways to specify expert knowledge:\n",
"- **Temporal Order / Tier information**: Specify the tiers/temporal order for sets of variables.\n",
"- **Search Space**: Specify the set of edges that the algorithms should search through.\n",
"- **Forbidden Edges**: Edges that should not be present in the final model.\n",
"- **Required Edges**: Edges that must be present in the final model.\n",
"\n",
"This tutorial walks through examples of specifying expert knowledge with the causal discovery algorithms implemented in pgmpy. Expert knowledge is supported in pgmpy with the following algorithms:\n",
"\n",
"- **PC**\n",
"- **Hill Climb Search**\n",
"- **GES**\n",
"- **ExpertInLoop**\n",
"\n",
"The models used in this notebook are described in detail at [the bnlearn repository](https://www.bnlearn.com/bnrepository/).\n"
]
},
{
"cell_type": "markdown",
"id": "d0d24cb4",
"metadata": {},
"source": [
"## Temporal Order with the PC Algorithm"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "33f78cd0",
"metadata": {},
"outputs": [],
"source": [
"# Imports and configuration\n",
"import logging\n",
"import itertools\n",
"\n",
"from IPython.display import display\n",
"from pgmpy.utils import get_example_model\n",
"from pgmpy.global_vars import logger, config\n",
"\n",
"logger.setLevel(logging.ERROR)\n",
"config.set_backend(\"numpy\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4f708849",
"metadata": {},
"outputs": [],
"source": [
"# Load an example model and simulate data from it.\n",
"cancer = get_example_model(\"cancer\")\n",
"cancer_samples = cancer.simulate(n_samples=5000, seed=42)\n",
"\n",
"# Plot the `cancer` model\n",
"diag = cancer.to_graphviz()\n",
"diag.layout(prog=\"dot\")\n",
"display(diag)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4b1a52da",
"metadata": {},
"outputs": [],
"source": [
"# Run the standard PC algorithm and plot the results\n",
"\n",
"from pgmpy.estimators import PC, ExpertKnowledge\n",
"\n",
"learned_model = PC(cancer_samples).estimate()\n",
"\n",
"diag = learned_model.to_graphviz()\n",
"diag.layout(prog=\"dot\")\n",
"display(diag)"
]
},
{
"cell_type": "markdown",
"id": "7340f704",
"metadata": {},
"source": [
"As the learned model shows, the PC algorithm does not correctly recover the original data-generating DAG.\n",
"\n",
"Suppose we know that smoking (`Smoker`) and air pollution (`Pollution`) contribute to lung cancer (`Cancer`). Breathing difficulties (`Dyspnoea`) and a positive X-ray report (`Xray`) are effects of cancer that support a diagnosis, so they cannot be its causes in this scenario. Based on this, we can craft a temporal order that places the causes of cancer in the first tier, `Cancer` in the second tier, and its effects together in the last tier."
]
},
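{
"cell_type": "markdown",
"id": "7a1c3e90",
"metadata": {},
"source": [
"Conceptually, a temporal order states that a variable in a later tier cannot cause a variable in an earlier tier. The next cell is a minimal sketch of that idea, not of how pgmpy applies the ordering internally: the helper `tiers_to_forbidden_edges` is hypothetical, and it assumes that edges between variables in the same tier are left unconstrained."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a1c3e91",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"\n",
"# Hypothetical helper (not part of pgmpy): list the directed edges that a tiered\n",
"# temporal order rules out. An edge u -> v is forbidden when u sits in a strictly\n",
"# later tier than v; edges within a single tier are not constrained here.\n",
"def tiers_to_forbidden_edges(tiers):\n",
"    forbidden = []\n",
"    # combinations() yields index pairs (early, late) with early < late.\n",
"    for early, late in itertools.combinations(range(len(tiers)), 2):\n",
"        for u in tiers[late]:\n",
"            for v in tiers[early]:\n",
"                forbidden.append((u, v))\n",
"    return forbidden\n",
"\n",
"\n",
"tiers = [[\"Pollution\", \"Smoker\"], [\"Cancer\"], [\"Dyspnoea\", \"Xray\"]]\n",
"forbidden = tiers_to_forbidden_edges(tiers)\n",
"print(len(forbidden))  # 8\n",
"print(forbidden[:2])  # [('Cancer', 'Pollution'), ('Cancer', 'Smoker')]"
]
},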
{
"cell_type": "code",
"execution_count": 4,
"id": "66b91a8d",
"metadata": {},
"outputs": [],
"source": [
"# PC algorithm with the temporal knowledge.\n",
"\n",
"expert_knowledge = ExpertKnowledge(\n",
" temporal_order=[[\"Pollution\", \"Smoker\"], [\"Cancer\"], [\"Dyspnoea\", \"Xray\"]]\n",
")\n",
"\n",
"est_model = PC(cancer_samples).estimate(expert_knowledge=expert_knowledge)\n",
"\n",
"diag = est_model.to_graphviz()\n",
"diag.layout(prog=\"dot\")\n",
"display(diag)"
]
},
{
"cell_type": "markdown",
"id": "d943d672",
"metadata": {},
"source": [
"Note: Results obtained with any form of expert knowledge, especially with PC, are sensitive to the accuracy of that knowledge; incorrect knowledge can make the learned model worse."
]
},
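{
"cell_type": "markdown",
"id": "c4e8f2b0",
"metadata": {},
"source": [
"Because the learned model hinges on the knowledge being correct, it can help to check the knowledge against what is known. In this simulated example the true graph is available, so the sketch in the next cell checks whether any edge of the `cancer` network runs from a later tier to an earlier one. The edge list is hard-coded to mirror `cancer.edges()`, and `tier_violations` is an illustrative helper, not a pgmpy function. With real data the true graph is unknown, so this only illustrates how a wrong tier assignment contradicts the data-generating structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4e8f2b1",
"metadata": {},
"outputs": [],
"source": [
"# Edges of the bnlearn `cancer` network (the same edges as cancer.edges()).\n",
"true_edges = [\n",
"    (\"Pollution\", \"Cancer\"),\n",
"    (\"Smoker\", \"Cancer\"),\n",
"    (\"Cancer\", \"Xray\"),\n",
"    (\"Cancer\", \"Dyspnoea\"),\n",
"]\n",
"\n",
"\n",
"def tier_violations(edges, tiers):\n",
"    # Map each variable to the index of the tier it belongs to.\n",
"    tier_of = {node: i for i, tier in enumerate(tiers) for node in tier}\n",
"    # An edge violates the order if it points from a later tier to an earlier one.\n",
"    return [(u, v) for u, v in edges if tier_of[u] > tier_of[v]]\n",
"\n",
"\n",
"correct_tiers = [[\"Pollution\", \"Smoker\"], [\"Cancer\"], [\"Dyspnoea\", \"Xray\"]]\n",
"wrong_tiers = [[\"Pollution\", \"Smoker\"], [\"Dyspnoea\", \"Xray\"], [\"Cancer\"]]\n",
"\n",
"print(tier_violations(true_edges, correct_tiers))  # []\n",
"print(tier_violations(true_edges, wrong_tiers))  # [('Cancer', 'Xray'), ('Cancer', 'Dyspnoea')]"
]
},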
{
"cell_type": "markdown",
"id": "705df5c8",
"metadata": {},
"source": [
"## Expert knowledge based on a search space"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "129dbf0c",
"metadata": {},
"outputs": [],
"source": [
"# Load the Asia model and simulate data (https://www.bnlearn.com/bnrepository/discrete-small.html#asia)\n",
"asia = get_example_model(\"asia\")\n",
"asia_samples = asia.simulate(n_samples=10000, seed=42)\n",
"\n",
"diag = asia.to_graphviz()\n",
"diag.layout(prog=\"dot\")\n",
"display(diag)"
]
},
{
"cell_type": "markdown",
"id": "f88e492f",
"metadata": {},
"source": [
"Search spaces help because it is usually difficult to distinguish between direct and indirect effects. For example, in the Asia model above, we might know that `smoke` has a causal effect on `dysp` without knowing exactly through which pathway. With a search space, we specify a supergraph over the variables: it contains an edge between any two variables that we think could be causally related, even if we are unsure of the exact causal path, and the algorithm searches only among these edges.\n",
"\n",
"As an example, for the `asia` model we can allow edges from all three diseases (`tub`, `lung`, and `bronc`) to both symptoms (`xray` and `dysp`)."
]
},
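{
"cell_type": "markdown",
"id": "9b2d6f40",
"metadata": {},
"source": [
"Listing the search space by hand is error-prone, since a misspelled name or a variable from another model is easy to miss. A minimal sketch, assuming the variable names of the bnlearn `asia` network listed in the next cell, builds the disease-to-symptom edges with `itertools.product` and flags any edge whose endpoints are not variables of the model. `unknown_endpoints` is a hypothetical helper, not part of pgmpy."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b2d6f41",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"# Variables of the bnlearn `asia` network (the same set as asia.nodes()).\n",
"asia_nodes = {\"asia\", \"tub\", \"smoke\", \"lung\", \"bronc\", \"either\", \"xray\", \"dysp\"}\n",
"\n",
"diseases = [\"tub\", \"lung\", \"bronc\"]\n",
"symptoms = [\"xray\", \"dysp\"]\n",
"\n",
"# Allow every disease to point to every symptom.\n",
"disease_to_symptom = list(itertools.product(diseases, symptoms))\n",
"\n",
"\n",
"def unknown_endpoints(edges, nodes):\n",
"    # Hypothetical helper: return edges that mention a name that is not a model variable.\n",
"    return [(u, v) for u, v in edges if u not in nodes or v not in nodes]\n",
"\n",
"\n",
"print(len(disease_to_symptom))  # 6\n",
"print(unknown_endpoints(disease_to_symptom, asia_nodes))  # []\n",
"print(unknown_endpoints([(\"lung\", \"x-ray\"), (\"smoke\", \"lung\")], asia_nodes))  # [('lung', 'x-ray')]"
]
},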
{
"cell_type": "code",
"execution_count": 6,
"id": "56a67b5a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pgmpy.metrics import SHD\n",
"\n",
"search_space = [\n",
" (\"smoke\", \"bronc\"),\n",
" (\"smoke\", \"lung\"),\n",
" (\"smoke\", \"tub\"),\n",
" (\"tub\", \"asia\"),\n",
" (\"asia\", \"tub\"),\n",
" (\"tub\", \"either\"),\n",
" (\"lung\", \"either\"),\n",
" (\"bronc\", \"either\"),\n",
" (\"tub\", \"xray\"),\n",
" (\"tub\", \"dysp\"),\n",
" (\"lung\", \"xray\"),\n",
" (\"lung\", \"dysp\"),\n",
" (\"bronc\", \"xray\"),\n",
" (\"bronc\", \"dysp\"),\n",
" (\"either\", \"xray\"),\n",
" (\"either\", \"dysp\"),\n",
"]\n",
"\n",
"expert_knowledge = ExpertKnowledge(search_space=search_space)\n",
"\n",
"est_model = PC(asia_samples).estimate(\n",
"    expert_knowledge=expert_knowledge, ci_test=\"pillai\", enforce_expert_knowledge=True\n",
")\n",
"SHD(asia, est_model)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "40b488fa",
"metadata": {},
"outputs": [],
"source": [
"# Plot the estimated model\n",
"\n",
"diag = est_model.to_graphviz()\n",
"diag.layout(prog=\"dot\")\n",
"display(diag)"
]
},
{
"cell_type": "markdown",
"id": "3cdc03f6",
"metadata": {},
"source": [
"## Knowledge base of required and forbidden edges"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b81d935b",
"metadata": {},
"outputs": [],
"source": [
"child = get_example_model(\"child\")\n",
"child_samples = child.simulate(n_samples=10000, seed=42)\n",
"\n",
"diag = child.to_graphviz()\n",
"diag.layout(prog=\"dot\")\n",
"display(diag)"
]
},
{
"cell_type": "markdown",
"id": "9c51da29",
"metadata": {},
"source": [
"Based on reasoning similar to the first section, we can now create a knowledge base of required and forbidden edges. For example, the type of disease (`Disease`) directly affects six markers of disease (`LVH`, `LungParench`, etc.). Its only cause in the model is birth asphyxia (`BirthAsphyxia`), so `Disease` cannot have incoming edges from any other node. Similarly, the CO2 level (`CO2`) affects the CO2 report (`CO2Report`), not the other way around."
]
},
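{
"cell_type": "markdown",
"id": "e5a7c310",
"metadata": {},
"source": [
"Before running an estimator, it is worth checking that the knowledge base is internally consistent. The next cell uses a hypothetical helper, `check_knowledge` (not part of pgmpy), and a small illustrative subset of the `child` variables to report edges that mention unknown variables and edges that are both required and forbidden."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5a7c311",
"metadata": {},
"outputs": [],
"source": [
"def check_knowledge(nodes, required_edges, forbidden_edges):\n",
"    # Hypothetical helper (not part of pgmpy): collect human-readable problems.\n",
"    nodes = set(nodes)\n",
"    problems = []\n",
"    # Every endpoint of a required or forbidden edge must be a model variable.\n",
"    for u, v in list(required_edges) + list(forbidden_edges):\n",
"        for name in (u, v):\n",
"            if name not in nodes:\n",
"                problems.append(f\"unknown variable: {name}\")\n",
"    # An edge cannot be both required and forbidden.\n",
"    for edge in sorted(set(required_edges) & set(forbidden_edges)):\n",
"        problems.append(f\"edge both required and forbidden: {edge}\")\n",
"    return problems\n",
"\n",
"\n",
"# A small subset of the `child` variables, for illustration only.\n",
"example_nodes = [\"BirthAsphyxia\", \"Disease\", \"CO2\", \"CO2Report\", \"Grunting\", \"GruntingReport\"]\n",
"# `Gruntin` is a deliberate misspelling of `Grunting`.\n",
"example_required = [(\"CO2\", \"CO2Report\"), (\"Gruntin\", \"GruntingReport\")]\n",
"example_forbidden = [(\"CO2\", \"Disease\"), (\"CO2\", \"CO2Report\")]\n",
"\n",
"for problem in check_knowledge(example_nodes, example_required, example_forbidden):\n",
"    print(problem)\n",
"# unknown variable: Gruntin\n",
"# edge both required and forbidden: ('CO2', 'CO2Report')"
]
},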
{
"cell_type": "code",
"execution_count": 9,
"id": "03bcfb85",
"metadata": {},
"outputs": [],
"source": [
"required_edges = [(\"CO2\", \"CO2Report\"),\n",
" (\"ChestXray\", \"XrayReport\"),\n",
" (\"LVH\", \"LVHreport\"),\n",
" (\"Grunting\", \"GruntingReport\")]\n",
"# `Disease` should have no incoming edges except the one from `BirthAsphyxia`,\n",
"# so forbid every other variable from pointing into it.\n",
"forbidden_edges = [\n",
"    (u, \"Disease\") for u in child.nodes() if u not in (\"Disease\", \"BirthAsphyxia\")\n",
"]\n",
"\n",
"expert_knowledge = ExpertKnowledge(\n",
" required_edges=required_edges, forbidden_edges=forbidden_edges\n",
")\n",
"est_model = PC(child_samples).estimate(\n",
"    expert_knowledge=expert_knowledge, ci_test=\"pillai\", enforce_expert_knowledge=True\n",
")\n",
"SHD(child, est_model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d33b6f03",
"metadata": {},
"outputs": [],
"source": [
"# Plot the estimated model\n",
"\n",
"diag = est_model.to_graphviz()\n",
"diag.layout(prog=\"dot\")\n",
"display(diag)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}