{ "cells": [ { "cell_type": "markdown", "id": "55155365-701c-41e6-9079-51fb02bb2036", "metadata": {}, "source": [ "# Simulating Data From Bayesian Networks\n", "\n", "pgmpy's `DiscreteBayesianNetwork.simulate` method simulates data from a fully specified Bayesian network (structure and all CPDs) under various conditions. These conditions can be any combination of:\n", "1. Hard Evidence\n", "2. Virtual Evidence\n", "3. Hard Intervention\n", "4. Virtual Intervention\n", "\n", "Users can also pass values for some of the variables in the model (partial samples, via the `partial_samples` argument). The simulation then uses these values for those variables instead of generating them." ] }, { "cell_type": "code", "execution_count": 1, "id": "2feeb277", "metadata": {}, "outputs": [], "source": [ "# A helper function to compute empirical probability distributions from simulated samples.\n", "def get_distribution(samples, variables=None):\n", "    \"\"\"\n", "    Return the empirical distribution of `variables` in the `samples` DataFrame.\n", "\n", "    For marginal distribution, P(A): get_distribution(samples, variables=['A'])\n", "    For joint distribution, P(A, B): get_distribution(samples, variables=['A', 'B'])\n", "    \"\"\"\n", "    if variables is None:\n", "        raise ValueError(\"variables must be specified\")\n", "\n", "    return samples.groupby(variables).size() / samples.shape[0]" ] }, { "cell_type": "code", "execution_count": 2, "id": "cdb2e2f1", "metadata": {}, "outputs": [], "source": [ "# Do not print warnings\n", "import logging\n", "from pgmpy.global_vars import logger\n", "logger.setLevel(logging.ERROR)\n", "\n", "# Specify the model to simulate data from.\n", "from pgmpy.factors.discrete import TabularCPD\n", "from pgmpy.utils import get_example_model\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "alarm = get_example_model(\"alarm\")" ] }, { "cell_type": "markdown", "id": "6c064112", "metadata": {}, "source": [ "## 1. Standard simulation\n", "\n", "Without any conditions, the `simulate` method draws samples from the joint distribution defined by the model." ] }, { "cell_type": "code", "execution_count": 3, "id": "94e02da7", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1d87045741e04846a89c99bdb785650b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/37 [00:00\n", "
\n", "" ], "text/plain": [ " TPR PAP MINVOL HREKG EXPCO2 DISCONNECT VENTMACH VENTLUNG \\\n", "0 LOW NORMAL ZERO NORMAL LOW FALSE NORMAL ZERO \n", "1 HIGH NORMAL ZERO HIGH LOW TRUE NORMAL ZERO \n", "2 LOW NORMAL ZERO HIGH LOW FALSE NORMAL ZERO \n", "3 LOW NORMAL ZERO HIGH LOW FALSE NORMAL ZERO \n", "4 NORMAL HIGH ZERO HIGH LOW FALSE NORMAL ZERO \n", "\n", " LVEDVOLUME HR ... SHUNT VENTTUBE MINVOLSET LVFAILURE ERRLOWOUTPUT \\\n", "0 NORMAL HIGH ... NORMAL LOW NORMAL FALSE FALSE \n", "1 NORMAL HIGH ... NORMAL ZERO NORMAL FALSE FALSE \n", "2 NORMAL HIGH ... NORMAL LOW NORMAL FALSE FALSE \n", "3 NORMAL HIGH ... NORMAL LOW NORMAL FALSE FALSE \n", "4 NORMAL HIGH ... NORMAL LOW NORMAL FALSE FALSE \n", "\n", " HRBP FIO2 BP HISTORY STROKEVOLUME \n", "0 HIGH NORMAL LOW FALSE NORMAL \n", "1 HIGH NORMAL HIGH FALSE NORMAL \n", "2 HIGH NORMAL LOW FALSE NORMAL \n", "3 HIGH NORMAL LOW FALSE NORMAL \n", "4 HIGH NORMAL NORMAL FALSE NORMAL \n", "\n", "[5 rows x 37 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "samples = alarm.simulate(n_samples=int(1e4))\n", "samples.head()" ] }, { "cell_type": "markdown", "id": "b6d6664d", "metadata": {}, "source": [ "## 2. Simulation under hard evidence\n", "\n", "Specifying hard evidence fixes each evidence variable to its observed value in every sample; the remaining variables are drawn from their distribution conditional on that evidence."
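] }, { "cell_type": "markdown", "id": "hard-evidence-rejection-sketch", "metadata": {}, "source": [ "One way to build intuition for hard evidence is rejection sampling: draw from the unconditioned model and keep only the rows that agree with the evidence. The next cell is an illustrative sketch that applies this filter to the unconditioned `samples` drawn in Section 1, using the `get_distribution` helper defined at the top of the notebook. It is not meant to describe how `simulate` handles evidence internally." ] }, { "cell_type": "code", "execution_count": null, "id": "hard-evidence-rejection-code", "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch: rejection sampling by hand.\n", "# Keep only the unconditioned samples from Section 1 that match the evidence.\n", "evidence = {\"CVP\": \"NORMAL\", \"HR\": \"HIGH\"}\n", "accepted = samples\n", "for var, state in evidence.items():\n", "    accepted = accepted[accepted[var] == state]\n", "\n", "print(f\"Accepted {len(accepted)} of {len(samples)} samples\")\n", "\n", "# Empirical P(BP | CVP=NORMAL, HR=HIGH); compare with\n", "# get_distribution(samples, variables=[\"BP\"]) once the next cell has run.\n", "get_distribution(accepted, variables=[\"BP\"])"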
] }, { "cell_type": "code", "execution_count": 4, "id": "6ebdeade", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8a54addf4b4f4f22b66f3b155c433192", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10000 [00:00\n", "
\n", "" ], "text/plain": [ " TPR PAP MINVOL HREKG EXPCO2 DISCONNECT VENTMACH VENTLUNG \\\n", "0 NORMAL NORMAL ZERO HIGH LOW FALSE NORMAL ZERO \n", "1 LOW NORMAL ZERO HIGH LOW FALSE NORMAL ZERO \n", "2 LOW NORMAL ZERO HIGH LOW FALSE NORMAL ZERO \n", "3 LOW NORMAL ZERO HIGH LOW FALSE NORMAL ZERO \n", "4 LOW NORMAL NORMAL NORMAL LOW FALSE NORMAL ZERO \n", "\n", " LVEDVOLUME HR ... SHUNT VENTTUBE MINVOLSET LVFAILURE ERRLOWOUTPUT \\\n", "0 NORMAL HIGH ... NORMAL LOW NORMAL FALSE FALSE \n", "1 NORMAL HIGH ... NORMAL LOW NORMAL FALSE FALSE \n", "2 NORMAL HIGH ... NORMAL LOW NORMAL FALSE FALSE \n", "3 NORMAL HIGH ... NORMAL LOW NORMAL FALSE FALSE \n", "4 NORMAL HIGH ... HIGH LOW NORMAL FALSE FALSE \n", "\n", " HRBP FIO2 BP HISTORY STROKEVOLUME \n", "0 HIGH NORMAL LOW FALSE LOW \n", "1 HIGH NORMAL LOW FALSE LOW \n", "2 HIGH LOW LOW FALSE NORMAL \n", "3 HIGH NORMAL LOW FALSE NORMAL \n", "4 HIGH NORMAL LOW FALSE NORMAL \n", "\n", "[5 rows x 37 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evidence = {\"CVP\": \"NORMAL\", \"HR\": \"HIGH\"}\n", "samples = alarm.simulate(n_samples=int(1e4), evidence=evidence)\n", "samples.head()" ] }, { "cell_type": "code", "execution_count": 5, "id": "2587ee03-7466-4bd2-8f7b-9d7e32780637", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n" ] } ], "source": [ "# Every sample should have HR == HIGH and CVP == NORMAL.\n", "print(all(samples.HR == \"HIGH\"))\n", "print(all(samples.CVP == \"NORMAL\"))" ] }, { "cell_type": "markdown", "id": "e58aaee7", "metadata": {}, "source": [ "## 3. Simulation under soft/virtual evidence\n", "\n", "Unlike hard evidence, which fixes each evidence variable to a single value, virtual (soft) evidence lets users specify a probability distribution over the states of a variable."
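] }, { "cell_type": "markdown", "id": "virtual-evidence-sketch", "metadata": {}, "source": [ "The next cell is a minimal sketch of passing virtual evidence to `simulate`. It assumes the evidence is supplied through the `virtual_evidence` argument as a list of single-variable `TabularCPD` objects. The weights 0.2, 0.3 and 0.5 on the states of `CVP` are hypothetical values chosen for illustration. The state order is read from the model's CPD instead of being hard-coded. How closely the simulated `CVP` distribution follows these weights depends on how `simulate` combines them with the model; the sketch only shows how to inspect the result with the `get_distribution` helper." ] }, { "cell_type": "code", "execution_count": null, "id": "virtual-evidence-code", "metadata": {}, "outputs": [], "source": [ "# Hypothetical virtual evidence on CVP; the weights are illustrative only.\n", "# Assumes CVP has three states, matching the three weights below.\n", "cvp_states = alarm.get_cpds(\"CVP\").state_names[\"CVP\"]\n", "virtual_cvp = TabularCPD(\n", "    variable=\"CVP\",\n", "    variable_card=len(cvp_states),\n", "    values=[[0.2], [0.3], [0.5]],  # one row per state, in the order of cvp_states\n", "    state_names={\"CVP\": cvp_states},\n", ")\n", "\n", "ve_samples = alarm.simulate(n_samples=int(1e4), virtual_evidence=[virtual_cvp])\n", "\n", "# Empirical distribution of CVP in the simulated data\n", "get_distribution(ve_samples, variables=[\"CVP\"])"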
] }, { "cell_type": "code", "execution_count": 6, "id": "edb1b6dd", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bb0f2d625fdd4148b6bf35f760366ec2", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10000 [00:00\n", "
\n", "" ], "text/plain": [ " TPR PAP MINVOL HREKG EXPCO2 DISCONNECT VENTMACH VENTLUNG LVEDVOLUME \\\n", "0 HIGH NORMAL ZERO HIGH LOW FALSE NORMAL ZERO LOW \n", "1 NORMAL NORMAL ZERO HIGH LOW FALSE NORMAL ZERO LOW \n", "2 NORMAL NORMAL ZERO HIGH LOW FALSE NORMAL ZERO LOW \n", "3 LOW NORMAL ZERO HIGH LOW FALSE NORMAL ZERO NORMAL \n", "4 NORMAL NORMAL ZERO HIGH HIGH FALSE NORMAL ZERO HIGH \n", "\n", " HR ... SHUNT VENTTUBE MINVOLSET LVFAILURE ERRLOWOUTPUT HRBP \\\n", "0 HIGH ... NORMAL LOW NORMAL FALSE FALSE NORMAL \n", "1 HIGH ... NORMAL LOW NORMAL TRUE FALSE HIGH \n", "2 HIGH ... NORMAL LOW NORMAL TRUE FALSE HIGH \n", "3 HIGH ... NORMAL LOW NORMAL FALSE FALSE HIGH \n", "4 HIGH ... NORMAL LOW NORMAL FALSE FALSE HIGH \n", "\n", " FIO2 BP HISTORY STROKEVOLUME \n", "0 NORMAL HIGH FALSE NORMAL \n", "1 NORMAL LOW FALSE LOW \n", "2 NORMAL LOW TRUE LOW \n", "3 LOW LOW FALSE NORMAL \n", "4 NORMAL NORMAL FALSE LOW \n", "\n", "[5 rows x 37 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "samples = alarm.simulate(n_samples=int(1e4), do={\"CVP\": \"NORMAL\", \"HR\": \"HIGH\"})\n", "samples.head()" ] }, { "cell_type": "markdown", "id": "4b9fa901", "metadata": {}, "source": [ "## 5. Simulation under soft/virtual intervention\n", "\n", "Similar to virtual evidence, users can also specify a soft (virtual) intervention. Instead of forcing a variable to a single state, as `do` does, it sets the variable's distribution to a user-specified one." ] }, { "cell_type": "code", "execution_count": 9, "id": "eab6d41b", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "215b2f38eada4346ac21fd581671b37c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10000 [00:00