{ "cells": [
{ "cell_type": "markdown", "id": "55155365-701c-41e6-9079-51fb02bb2036", "metadata": {}, "source": [ "# Simulating Data From Bayesian Networks\n", "\n", "pgmpy provides the `DiscreteBayesianNetwork.simulate` method for simulating data from a fully specified Bayesian network under various conditions. These conditions can be any combination of:\n", "1. Virtual Evidence\n", "2. Hard Evidence\n", "3. Virtual Intervention\n", "4. Hard Intervention\n", "\n", "Users can also pass data for some of the variables in the model (partial samples) to the simulation method. Those variables are then fixed to the specified values in the simulated data." ] },
{ "cell_type": "code", "execution_count": 1, "id": "2feeb277", "metadata": {}, "outputs": [], "source": [ "# A helper function to compute probability distributions from simulated samples.\n", "def get_distribution(samples, variables=None):\n", "    \"\"\"\n", "    For marginal distribution, P(A): get_distribution(samples, variables=['A'])\n", "    For joint distribution, P(A, B): get_distribution(samples, variables=['A', 'B'])\n", "    \"\"\"\n", "    if variables is None:\n", "        raise ValueError(\"variables must be specified\")\n", "\n", "    # Relative frequency of each observed combination of values.\n", "    return samples.groupby(variables).size() / samples.shape[0]" ] },
{ "cell_type": "code", "execution_count": 2, "id": "cdb2e2f1", "metadata": {}, "outputs": [], "source": [ "# Suppress pgmpy warnings; only log errors.\n", "import logging\n", "from pgmpy.global_vars import logger\n", "logger.setLevel(logging.ERROR)\n", "\n", "# Specify the model to simulate data from.\n", "from pgmpy.factors.discrete import TabularCPD\n", "from pgmpy.utils import get_example_model\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "alarm = get_example_model(\"alarm\")" ] },
{ "cell_type": "markdown", "id": "6c064112", "metadata": {}, "source": [ "## 1. Standard simulation\n", "\n", "When no conditions are specified, the `simulate` method draws samples from the joint distribution of the model." ] },
{ "cell_type": "code", "execution_count": 3, "id": "94e02da7", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1d87045741e04846a89c99bdb785650b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/37 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "|  | TPR | PAP | MINVOL | HREKG | EXPCO2 | DISCONNECT | VENTMACH | VENTLUNG | LVEDVOLUME | HR | ... | SHUNT | VENTTUBE | MINVOLSET | LVFAILURE | ERRLOWOUTPUT | HRBP | FIO2 | BP | HISTORY | STROKEVOLUME |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | LOW | NORMAL | ZERO | NORMAL | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | LOW | FALSE | NORMAL |\n",
"| 1 | HIGH | NORMAL | ZERO | HIGH | LOW | TRUE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | ZERO | NORMAL | FALSE | FALSE | HIGH | NORMAL | HIGH | FALSE | NORMAL |\n",
"| 2 | LOW | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | LOW | FALSE | NORMAL |\n",
"| 3 | LOW | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | LOW | FALSE | NORMAL |\n",
"| 4 | NORMAL | HIGH | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | NORMAL | FALSE | NORMAL |\n",
"\n",
"5 rows × 37 columns
\n", "\n", "|  | TPR | PAP | MINVOL | HREKG | EXPCO2 | DISCONNECT | VENTMACH | VENTLUNG | LVEDVOLUME | HR | ... | SHUNT | VENTTUBE | MINVOLSET | LVFAILURE | ERRLOWOUTPUT | HRBP | FIO2 | BP | HISTORY | STROKEVOLUME |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | NORMAL | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | LOW | FALSE | LOW |\n",
"| 1 | LOW | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | LOW | FALSE | LOW |\n",
"| 2 | LOW | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | LOW | LOW | FALSE | NORMAL |\n",
"| 3 | LOW | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | LOW | FALSE | NORMAL |\n",
"| 4 | LOW | NORMAL | NORMAL | NORMAL | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | HIGH | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | LOW | FALSE | NORMAL |\n",
"\n",
"5 rows × 37 columns
\n", "\n", "|  | TPR | PAP | MINVOL | HREKG | EXPCO2 | DISCONNECT | VENTMACH | VENTLUNG | LVEDVOLUME | HR | ... | SHUNT | VENTTUBE | MINVOLSET | LVFAILURE | ERRLOWOUTPUT | HRBP | FIO2 | BP | HISTORY | STROKEVOLUME |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | HIGH | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | LOW | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | NORMAL | NORMAL | HIGH | FALSE | NORMAL |\n",
"| 1 | NORMAL | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | LOW | HIGH | ... | NORMAL | LOW | NORMAL | TRUE | FALSE | HIGH | NORMAL | LOW | FALSE | LOW |\n",
"| 2 | NORMAL | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | LOW | HIGH | ... | NORMAL | LOW | NORMAL | TRUE | FALSE | HIGH | NORMAL | LOW | TRUE | LOW |\n",
"| 3 | LOW | NORMAL | ZERO | HIGH | LOW | FALSE | NORMAL | ZERO | NORMAL | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | LOW | LOW | FALSE | NORMAL |\n",
"| 4 | NORMAL | NORMAL | ZERO | HIGH | HIGH | FALSE | NORMAL | ZERO | HIGH | HIGH | ... | NORMAL | LOW | NORMAL | FALSE | FALSE | HIGH | NORMAL | NORMAL | FALSE | LOW |\n",
"\n",
"5 rows × 37 columns
\n", "