{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Structure Learning in Bayesian Networks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we show a few examples of Causal Discovery or Structure Learning in pgmpy. pgmpy currently has the following algorithm for causal discovery:\n", "\n", "1. **PC**: Has $3$ variants original, stable, and parallel. PC is a constraint-based algorithm that utilizes Conditional Independence tests to construct the model.\n", "2. **Hill-Climb Search**: Hill-Climb Search is a greedy optimization-based algorithm that makes iterative local changes to the model structure such that it improves the overall score of the model.\n", "3. **Greedy Equivalence Search (GES)**: Another score-based method that makes greedy modifications to the model to improve its score iteratively.\n", "4. **ExpertInLoop**: An iterative algorithm that combines Conditional Independence testing with expert knowledge. The user or an LLM can act as the expert.\n", "5. **Exhaustive Search**: Exhaustive search iterates over all possible network structures on the given variables to find the most optimal one. As it tries to enumerate all possible network structures, it is intractable when the number of variables in the data is large.\n", "\n", "The following Conditional Independence Tests are available to use with PC algorithm.\n", "1. **Discrete Data**: When all variables are discrete/categorical.\n", " 1. **Chi-square test**: `ci_test=\"chi_square\"`\n", " 2. **G-squared**: `ci_test=\"g_sq\"`\n", " 3. **Log-likelihood**: Is equivalent to G-squared test. `ci_test=\"log_likelihood`\n", "2. **Continuous Data**: When all variables are continuous/numerical.\n", " 1. **Partial Correlation**: `ci_test=\"pearsonr\"`\n", "3. **Mixed Data**: When there is a mix of categorical and continuous variables.\n", " 1. 
**Pillai**: `ci_test=\"pillai\"`\n", "\n", "For Hill-Climb, Exhaustive Search, and GES, the following scoring methods can be used:\n", "1. **Discrete Data**: When all variables are discrete/categorical. \n", " 1. **BIC Score**: `scoring_method=\"bic-d\"`\n", " 2. **AIC Score**: `scoring_method=\"aic-d\"`\n", " 3. **K2 Score**: `scoring_method=\"k2\"`\n", " 4. **BDeu Score**: `scoring_method=\"bdeu\"`\n", " 5. **BDs Score**: `scoring_method=\"bds\"`\n", "2. **Continuous Data**: When all variables are continuous/numerical.\n", " 1. **Log-Likelihood**: `scoring_method=\"ll-g\"`\n", " 2. **AIC**: `scoring_method=\"aic-g\"`\n", " 3. **BIC**: `scoring_method=\"bic-g\"`\n", "3. **Mixed Data**: When there is a mix of discrete and continuous variables.\n", " 1. **AIC**: `scoring_method=\"aic-cg\"`\n", " 2. **BIC**: `scoring_method=\"bic-cg\"`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 0. Simulate some sample datasets" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from itertools import combinations\n", "\n", "import networkx as nx\n", "import numpy as np\n", "from sklearn.metrics import f1_score\n", "\n", "from pgmpy.estimators import PC, HillClimbSearch, GES\n", "from pgmpy.utils import get_example_model\n", "from pgmpy.metrics import SHD" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aa35b1a7897646d7ac141e27e04a6165", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/37 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -2.220446049250313e-16. Adjusting values.\n" ] }, { "data": { "text/html": [ "
\n", " | b1191 | \n", "cspG | \n", "eutG | \n", "fixC | \n", "cspA | \n", "yecO | \n", "yedE | \n", "sucA | \n", "cchB | \n", "yceP | \n", "... | \n", "dnaK | \n", "folK | \n", "ycgX | \n", "lacZ | \n", "nuoM | \n", "dnaG | \n", "b1583 | \n", "mopB | \n", "yaeM | \n", "ftsJ | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "1.060641 | \n", "2.044477 | \n", "0.341216 | \n", "1.448399 | \n", "-0.351716 | \n", "2.189750 | \n", "-1.689554 | \n", "-0.228456 | \n", "2.871002 | \n", "-0.433597 | \n", "... | \n", "1.522817 | \n", "1.645983 | \n", "1.595502 | \n", "2.465247 | \n", "-0.532987 | \n", "1.126289 | \n", "-0.302589 | \n", "0.773483 | \n", "3.884857 | \n", "0.729557 | \n", "
1 | \n", "0.632151 | \n", "0.964321 | \n", "0.830229 | \n", "0.696598 | \n", "0.639204 | \n", "0.058108 | \n", "-0.736189 | \n", "0.712095 | \n", "1.467498 | \n", "0.320727 | \n", "... | \n", "1.222602 | \n", "1.790727 | \n", "1.763590 | \n", "2.945772 | \n", "-2.532464 | \n", "1.460699 | \n", "2.732595 | \n", "0.097982 | \n", "2.566064 | \n", "0.652853 | \n", "
2 | \n", "0.585766 | \n", "2.862437 | \n", "0.922291 | \n", "0.370000 | \n", "0.723932 | \n", "2.487161 | \n", "-1.916624 | \n", "-0.300359 | \n", "2.050980 | \n", "-0.064301 | \n", "... | \n", "0.599305 | \n", "1.302091 | \n", "0.509717 | \n", "3.090268 | \n", "-1.745613 | \n", "0.168043 | \n", "0.851346 | \n", "-0.640472 | \n", "5.800712 | \n", "0.031888 | \n", "
3 | \n", "1.802866 | \n", "2.277038 | \n", "0.608559 | \n", "2.180283 | \n", "0.116453 | \n", "2.539035 | \n", "-1.656839 | \n", "-1.420540 | \n", "2.605192 | \n", "-0.160302 | \n", "... | \n", "2.015663 | \n", "2.823588 | \n", "2.101625 | \n", "3.521299 | \n", "-1.212391 | \n", "1.485369 | \n", "2.449296 | \n", "2.032116 | \n", "3.888917 | \n", "3.036650 | \n", "
4 | \n", "1.868548 | \n", "2.480999 | \n", "1.079364 | \n", "2.413862 | \n", "0.961743 | \n", "2.280195 | \n", "-1.740610 | \n", "-0.472692 | \n", "3.338063 | \n", "-0.509144 | \n", "... | \n", "1.142139 | \n", "1.251832 | \n", "-0.145789 | \n", "2.057217 | \n", "-2.862230 | \n", "-0.089721 | \n", "1.020832 | \n", "0.040589 | \n", "4.925848 | \n", "0.575977 | \n", "
5 rows × 46 columns
\n", "