Simulating Data From Bayesian Networks

pgmpy's DiscreteBayesianNetwork.simulate method lets users simulate data from a fully specified Bayesian network under various conditions. These conditions can be any combination of:

  1. Virtual Evidence

  2. Hard Evidence

  3. Virtual Intervention

  4. Hard Intervention

Users can also pass data for some of the model's variables (partial samples) to the simulation method. Those variables are then fixed to the provided values in the generated samples.

Lastly, users can generate data with missing values, with missingness governed by a user-defined CPD, to mimic real-world data and evaluate how missingness affects inference.

[1]:
# A helper function to compute probability distributions from simulated samples.
def get_distribution(samples, variables=None):
    """
    For marginal distribution, P(A): get_distribution(samples, variables=['A'])
    For joint distribution, P(A, B): get_distribution(samples, variables=['A', 'B'])
    """
    if variables is None:
        raise ValueError("variables must be specified")

    return samples.groupby(variables, observed=False).size() / samples.shape[0]
[2]:
# Do not print warnings
import logging
from pgmpy.global_vars import logger
logger.setLevel(logging.ERROR)

# Specify the model to simulate data from.
from pgmpy.factors.discrete import TabularCPD
from pgmpy.utils import get_example_model

import numpy as np
import pandas as pd

alarm = get_example_model("alarm")

1. Standard simulation

Without any specified conditions for simulation, the simulate method draws samples from the joint distribution of the model.

[3]:
samples = alarm.simulate(n_samples=int(1e4))
samples.head()
[3]:
LVEDVOLUME MINVOLSET TPR HYPOVOLEMIA INTUBATION HR DISCONNECT ARTCO2 ANAPHYLAXIS LVFAILURE ... HRBP HRSAT PVSAT PAP FIO2 CO ERRLOWOUTPUT PULMEMBOLUS HISTORY HREKG
0 HIGH NORMAL NORMAL TRUE NORMAL HIGH TRUE HIGH FALSE FALSE ... HIGH HIGH NORMAL NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
1 LOW NORMAL NORMAL FALSE NORMAL HIGH TRUE LOW FALSE TRUE ... NORMAL HIGH NORMAL NORMAL LOW LOW TRUE FALSE FALSE HIGH
2 NORMAL NORMAL LOW FALSE NORMAL HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
3 NORMAL NORMAL HIGH FALSE NORMAL HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
4 NORMAL NORMAL LOW FALSE NORMAL HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH

5 rows × 37 columns
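Drawing samples from the joint distribution is commonly done by forward (ancestral) sampling: visit the nodes in topological order and sample each one from its CPD given the values already drawn for its parents. The sketch below illustrates that idea in plain Python on a hypothetical two-node network (Rain -> WetGrass). The network, its probabilities and the names PARENTS, CPDS and forward_sample are made up for illustration; this is not pgmpy's implementation.

```python
import random

# Hypothetical network Rain -> WetGrass (illustration only).
# Each CPD maps a tuple of parent values to a distribution over the node's states.
PARENTS = {"Rain": [], "WetGrass": ["Rain"]}
CPDS = {
    "Rain": {(): {"yes": 0.2, "no": 0.8}},
    "WetGrass": {
        ("yes",): {"yes": 0.9, "no": 0.1},
        ("no",): {"yes": 0.1, "no": 0.9},
    },
}
ORDER = ["Rain", "WetGrass"]  # topological order: parents before children


def draw(dist, rng):
    """Draw one state from a {state: probability} dict."""
    states = list(dist)
    return rng.choices(states, weights=[dist[s] for s in states])[0]


def forward_sample(n, seed=0):
    """Ancestral sampling: sample each node given its already-sampled parents."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        row = {}
        for node in ORDER:
            key = tuple(row[p] for p in PARENTS[node])
            row[node] = draw(CPDS[node][key], rng)
        rows.append(row)
    return rows


samples = forward_sample(10_000)
share_rain = sum(r["Rain"] == "yes" for r in samples) / len(samples)
print(round(share_rain, 2))  # close to 0.2
```

Because every node is sampled after its parents, each row is a draw from the joint distribution, so the share of Rain=yes approaches 0.2 and that of WetGrass=yes approaches 0.2 × 0.9 + 0.8 × 0.1 = 0.26.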

2. Simulation under specified evidence

Specifying hard evidence for some variables fixes their values to the specified evidence value during simulation.

[4]:
evidence = {"CVP": "NORMAL", "HR": "HIGH"}
samples = alarm.simulate(n_samples=int(1e4), evidence=evidence)
samples.head()
[4]:
LVEDVOLUME MINVOLSET TPR HYPOVOLEMIA INTUBATION HR DISCONNECT ARTCO2 ANAPHYLAXIS LVFAILURE ... HRBP HRSAT PVSAT PAP FIO2 CO ERRLOWOUTPUT PULMEMBOLUS HISTORY HREKG
0 NORMAL NORMAL NORMAL FALSE NORMAL HIGH FALSE LOW FALSE FALSE ... HIGH HIGH HIGH NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
1 NORMAL NORMAL NORMAL FALSE NORMAL HIGH FALSE HIGH FALSE FALSE ... NORMAL HIGH LOW NORMAL NORMAL LOW TRUE FALSE FALSE HIGH
2 NORMAL NORMAL HIGH FALSE NORMAL HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
3 HIGH NORMAL LOW FALSE NORMAL HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW LOW NORMAL HIGH FALSE FALSE FALSE HIGH
4 HIGH NORMAL HIGH TRUE ONESIDED HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW HIGH NORMAL LOW FALSE FALSE FALSE HIGH

5 rows × 37 columns

[5]:
# All values of HR and CVP should be set to HIGH and NORMAL respectively.
print(all(samples.HR == "HIGH"))
print(all(samples.CVP == "NORMAL"))
True
True
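Under the usual semantics, hard evidence means sampling from the conditional distribution P(X | evidence), so the non-evidence variables are affected too. One textbook way to do this is rejection sampling: forward-sample complete rows and discard those that disagree with the evidence. The sketch below shows this on a hypothetical two-node network (Rain -> WetGrass); the network and function names are made up, and whether pgmpy uses rejection sampling internally is an assumption not confirmed here.

```python
import random

# Hypothetical network Rain -> WetGrass (illustration only).
P_RAIN = {"yes": 0.2, "no": 0.8}
P_WET_GIVEN_RAIN = {"yes": {"yes": 0.9, "no": 0.1}, "no": {"yes": 0.1, "no": 0.9}}


def draw(dist, rng):
    """Draw one state from a {state: probability} dict."""
    states = list(dist)
    return rng.choices(states, weights=[dist[s] for s in states])[0]


def rejection_sample(n, evidence, seed=0):
    """Forward-sample rows and keep only those that agree with the hard evidence."""
    rng = random.Random(seed)
    kept = []
    while len(kept) < n:
        rain = draw(P_RAIN, rng)
        wet = draw(P_WET_GIVEN_RAIN[rain], rng)
        row = {"Rain": rain, "WetGrass": wet}
        if all(row[var] == val for var, val in evidence.items()):
            kept.append(row)
    return kept


samples = rejection_sample(5_000, evidence={"WetGrass": "yes"})
print(all(r["WetGrass"] == "yes" for r in samples))  # True

# Exact value: P(Rain=yes | WetGrass=yes) = 0.2*0.9 / (0.2*0.9 + 0.8*0.1), about 0.692
share_rain_given_wet = sum(r["Rain"] == "yes" for r in samples) / len(samples)
print(round(share_rain_given_wet, 2))
```

Observing WetGrass=yes raises the share of Rain=yes from its prior 0.2 to about 0.69. Rejection sampling gets slow when the evidence is unlikely, because most drawn rows are discarded.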

3. Simulation under soft/virtual evidence

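Unlike hard evidence, which fixes the specified variables to the given values, virtual (soft) evidence expresses uncertain evidence about a variable as a TabularCPD over its states. The specified values do not become the marginal distribution of the variable: they reweight the model's own distribution of that variable (and, through it, the rest of the network), as the output below shows.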

[6]:
# Virtual evidence is specified as a TabularCPD. Values follow the state_names order:
# 0.2 for LOW, 0.3 for NORMAL and 0.5 for HIGH.
cvp_evidence = TabularCPD(variable="CVP",
                          variable_card=3,
                          values=[[0.2], [0.3], [0.5]],
                          state_names={"CVP": ["LOW", "NORMAL", "HIGH"]})
samples = alarm.simulate(n_samples=int(1e4), virtual_evidence=[cvp_evidence])
[7]:
# Check the marginal distribution of CVP
get_distribution(samples, variables=['CVP'])
[7]:
CVP
HIGH      0.2445
LOW       0.0689
NORMAL    0.6866
dtype: float64
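The resulting marginal is not 0.2/0.3/0.5. It is, however, close to what Bayes' rule gives when the virtual-evidence values are treated as a likelihood over the states of CVP and combined with the model's own distribution of CVP. The arithmetic below checks this reading, assuming the prior of CVP is approximately the empirical marginal of the unconditioned sample in section 7.1 (CVP_full: LOW 0.115, NORMAL 0.727, HIGH 0.158). It is an interpretation consistent with the numbers, not a statement of how pgmpy implements virtual evidence.

```python
# Assumed prior of CVP: empirical marginal from the unconditioned sample in section 7.1.
prior = {"LOW": 0.115, "NORMAL": 0.727, "HIGH": 0.158}

# Virtual-evidence values from cvp_evidence, in state_names order LOW, NORMAL, HIGH.
likelihood = {"LOW": 0.2, "NORMAL": 0.3, "HIGH": 0.5}

# Treat the virtual evidence as a likelihood: posterior is proportional to prior * likelihood.
unnormalized = {s: prior[s] * likelihood[s] for s in prior}
total = sum(unnormalized.values())
posterior = {s: round(v / total, 3) for s, v in unnormalized.items()}
print(posterior)  # {'LOW': 0.072, 'NORMAL': 0.681, 'HIGH': 0.247}
```

These values are within about 0.01 of the simulated marginal above (LOW 0.0689, NORMAL 0.6866, HIGH 0.2445).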

4. Simulation under specified intervention

Using the do argument, users can specify interventions on the model. The intervened variables are set to the specified values, and all incoming edges to these variables are removed from the model.

[8]:
samples = alarm.simulate(n_samples=int(1e4), do={"CVP": "NORMAL", "HR": "HIGH"})
samples.head()
[8]:
LVEDVOLUME MINVOLSET TPR HYPOVOLEMIA INTUBATION HR DISCONNECT ARTCO2 ANAPHYLAXIS LVFAILURE ... HRBP HRSAT PVSAT PAP FIO2 CO ERRLOWOUTPUT PULMEMBOLUS HISTORY HREKG
0 NORMAL NORMAL HIGH FALSE NORMAL HIGH FALSE LOW FALSE FALSE ... HIGH HIGH HIGH NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
1 HIGH NORMAL NORMAL TRUE NORMAL HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW HIGH NORMAL LOW FALSE FALSE FALSE HIGH
2 HIGH NORMAL LOW TRUE NORMAL HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW NORMAL NORMAL LOW FALSE FALSE FALSE HIGH
3 NORMAL NORMAL LOW FALSE NORMAL HIGH TRUE LOW TRUE FALSE ... HIGH HIGH HIGH LOW NORMAL HIGH FALSE FALSE FALSE HIGH
4 NORMAL NORMAL NORMAL FALSE NORMAL HIGH FALSE HIGH FALSE FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH

5 rows × 37 columns
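The practical difference between hard evidence and an intervention shows up in the upstream variables: conditioning on a variable updates its parents through Bayes' rule, while do() cuts the incoming edges, so the parents keep their own distribution. The sketch below computes both by exact enumeration on a hypothetical two-node network (Rain -> WetGrass); the network, probabilities and function names are made up for illustration.

```python
# Hypothetical network Rain -> WetGrass (illustration only).
P_RAIN = {"yes": 0.2, "no": 0.8}
P_WET_GIVEN_RAIN = {"yes": {"yes": 0.9, "no": 0.1}, "no": {"yes": 0.1, "no": 0.9}}
STATES = ["yes", "no"]


def joint(rain, wet, do_wet=None):
    """Joint probability; with do_wet set, WetGrass loses its parent and is fixed."""
    if do_wet is None:
        p_wet = P_WET_GIVEN_RAIN[rain][wet]
    else:
        p_wet = 1.0 if wet == do_wet else 0.0  # point mass replaces the CPD of WetGrass
    return P_RAIN[rain] * p_wet


def p_rain_yes(wet_value, intervene):
    """P(Rain=yes | WetGrass=wet_value), or P(Rain=yes | do(WetGrass=wet_value))."""
    do_wet = wet_value if intervene else None
    num = joint("yes", wet_value, do_wet)
    den = sum(joint(r, wet_value, do_wet) for r in STATES)
    return num / den


print(round(p_rain_yes("yes", intervene=False), 3))  # 0.692 (evidence)
print(round(p_rain_yes("yes", intervene=True), 3))   # 0.2 (intervention)
```

Observing WetGrass=yes changes the belief about Rain, while forcing WetGrass=yes does not, because the intervention removes the edge Rain -> WetGrass.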

5. Simulation under soft/virtual intervention

Similar to virtual evidence, a virtual (soft) intervention is specified as a TabularCPD over the intervened variable and passed through the virtual_intervention argument.

[9]:
cvp_intervention = TabularCPD(variable="CVP",
                              variable_card=3,
                              values=[[0.2], [0.3], [0.5]],
                              state_names={"CVP": ["LOW", "NORMAL", "HIGH"]})
samples = alarm.simulate(n_samples=int(1e4), virtual_intervention=[cvp_intervention])
get_distribution(samples, variables=["CVP"])  # P(CVP)
[9]:
CVP
HIGH      0.3735
LOW       0.2165
NORMAL    0.4100
dtype: float64

6. Partial samples

Users can also pass already-generated data for some variables (for example, from another simulation) to the simulation method. This is equivalent to specifying separate evidence for each generated sample.

[10]:
# Generate some data on CVP.
partial_cvp = pd.DataFrame(np.random.choice(["LOW", "NORMAL", "HIGH"], int(1e4)), columns=['CVP'])
samples = alarm.simulate(n_samples=int(1e4), partial_samples=partial_cvp)
[11]:
print(all(samples["CVP"] == partial_cvp["CVP"]))
True
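Treating each provided row as its own evidence can be sketched with per-row rejection sampling: for row i, keep redrawing a complete row until the clamped variable matches the i-th provided value. The toy network (Rain -> WetGrass) and the function simulate_with_partial are hypothetical illustrations, not pgmpy code.

```python
import random

# Hypothetical network Rain -> WetGrass (illustration only).
P_RAIN = {"yes": 0.2, "no": 0.8}
P_WET_GIVEN_RAIN = {"yes": {"yes": 0.9, "no": 0.1}, "no": {"yes": 0.1, "no": 0.9}}


def draw(dist, rng):
    """Draw one state from a {state: probability} dict."""
    states = list(dist)
    return rng.choices(states, weights=[dist[s] for s in states])[0]


def simulate_with_partial(partial_wet, seed=0):
    """Sample one row per provided WetGrass value, conditioned on that value.

    Per-row rejection sampling: redraw the row until WetGrass matches the provided value.
    """
    rng = random.Random(seed)
    rows = []
    for wet_value in partial_wet:
        while True:
            rain = draw(P_RAIN, rng)
            wet = draw(P_WET_GIVEN_RAIN[rain], rng)
            if wet == wet_value:
                rows.append({"Rain": rain, "WetGrass": wet})
                break
    return rows


# Hypothetical partial data for WetGrass, e.g. produced by another simulation.
source = random.Random(1)
partial = [source.choice(["yes", "no"]) for _ in range(1_000)]
rows = simulate_with_partial(partial)
print(all(r["WetGrass"] == w for r, w in zip(rows, partial)))  # True
```

Because each row is conditioned on its own provided value, the upstream variable (Rain) is drawn from a different conditional distribution in rows where WetGrass=yes than in rows where WetGrass=no.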

7. Simulate missing data

Lastly, users can generate data with missing values for specified variables, according to a user-defined CPD. The CPD's variable must be named after the target variable followed by a * (for example, CVP*) and must have two states: 1 (missing) and 0 (not missing). The parents of this CPD determine the missingness mechanism: no parents (MCAR), other observed variables (MAR), or the variable itself (MNAR). Optionally, the return_full argument adds a column with the original, unmasked values (for example, CVP_full) for comparison.
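Conceptually, the missingness CPD gives, for each generated row, the probability that the indicator CVP* equals 1; when it does, the value of CVP is blanked out. The sketch below mimics this in plain Python. The function apply_missingness, the `_full` bookkeeping and the toy rows are hypothetical stand-ins for the behaviour described above, not pgmpy code; the three probability functions copy the CPD values used in the MCAR, MAR and MNAR examples that follow.

```python
import random


def apply_missingness(rows, var, p_missing, seed=0):
    """Blank out `var` in each row with probability p_missing(row).

    The coin flip plays the role of the indicator "<var>*" (1 = missing, 0 = observed);
    the original value is kept in "<var>_full", similar in spirit to return_full=True.
    """
    rng = random.Random(seed)
    out = []
    for row in rows:
        new = dict(row)
        new[var + "_full"] = row[var]
        if rng.random() < p_missing(row):  # indicator drawn as 1: value goes missing
            new[var] = None
        out.append(new)
    return out


# MCAR, MAR and MNAR differ only in which values p_missing may look at.
def mcar(row):
    return 0.4


def mar(row):
    return {"LOW": 0.2, "NORMAL": 0.8, "HIGH": 0.3}[row["LVEDVOLUME"]]


def mnar(row):
    return {"LOW": 0.8, "NORMAL": 0.6, "HIGH": 0.4}[row["CVP"]]


# Hypothetical complete data: 1000 identical rows, just to exercise the function.
rows = [{"LVEDVOLUME": "NORMAL", "CVP": "LOW"} for _ in range(1000)]
masked = apply_missingness(rows, "CVP", mnar)
print(sum(r["CVP"] is None for r in masked))  # roughly 800, since P(missing | CVP=LOW) = 0.8
```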

7.1. Missing completely at random (MCAR)

[12]:
# CVP data missing completely randomly with 0.4 probability
missing_CVP = TabularCPD(
    variable="CVP*",
    variable_card=2,
    values=[[0.6],
            [0.4]], # Missing probability = 0.4
    state_names={"CVP*": [0, 1]}
)

samples = alarm.simulate(n_samples=1000, missing_prob=[missing_CVP], return_full=True)
samples.head()
[12]:
LVEDVOLUME MINVOLSET TPR CVP_full HYPOVOLEMIA INTUBATION HR DISCONNECT ARTCO2 ANAPHYLAXIS ... HRBP HRSAT PVSAT PAP FIO2 CO ERRLOWOUTPUT PULMEMBOLUS HISTORY HREKG
0 NORMAL HIGH NORMAL NORMAL FALSE NORMAL NORMAL FALSE LOW FALSE ... LOW LOW HIGH NORMAL NORMAL LOW FALSE FALSE FALSE LOW
1 HIGH LOW LOW NORMAL TRUE NORMAL HIGH FALSE LOW FALSE ... HIGH HIGH HIGH NORMAL NORMAL LOW FALSE FALSE FALSE HIGH
2 HIGH NORMAL NORMAL NORMAL TRUE NORMAL HIGH FALSE HIGH FALSE ... HIGH HIGH LOW NORMAL LOW LOW FALSE FALSE FALSE HIGH
3 NORMAL NORMAL NORMAL NORMAL FALSE NORMAL HIGH FALSE LOW FALSE ... HIGH HIGH HIGH LOW NORMAL HIGH FALSE FALSE FALSE HIGH
4 NORMAL NORMAL NORMAL NORMAL FALSE NORMAL HIGH FALSE LOW FALSE ... NORMAL HIGH LOW NORMAL NORMAL HIGH TRUE FALSE FALSE HIGH

5 rows × 38 columns

[13]:
print(f"Missing values: {samples['CVP'].isna().sum()}/{len(samples['CVP'])}")
print()

print("Original Distribution:")
print(get_distribution(samples, variables="CVP_full"))
print()
print("Distribution of Missing/Removed")
print(get_distribution(samples.loc[samples["CVP"].isna()], variables="CVP_full")) # Since removal was completely random, we expect minimal change in distribution
Missing values: 434/1000

Original Distribution:
CVP_full
HIGH      0.158
LOW       0.115
NORMAL    0.727
dtype: float64

Distribution of Missing/Removed
CVP_full
HIGH      0.168203
LOW       0.117512
NORMAL    0.714286
dtype: float64
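Under MCAR, the number of missing values is binomial with n = 1000 and p = 0.4, so about 400 are expected. The short calculation below (standard binomial arithmetic, using the count reported above) shows how far 434 is from that expectation.

```python
import math

n, p = 1000, 0.4          # sample size and MCAR missing probability from the CPD above
observed_missing = 434    # count reported in the output above

mean = n * p                       # expected number of missing values
sd = math.sqrt(n * p * (1 - p))    # binomial standard deviation
z = (observed_missing - mean) / sd
print(round(mean), round(sd, 1), round(z, 2))  # 400 15.5 2.19
```

A deviation of about 2.2 standard deviations is uncommon but possible in a single run; with a larger n_samples the missing fraction concentrates more tightly around 0.4.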

7.2. Missing at random (MAR)

[14]:
# CVP data missing depending on the observed LVEDVOLUME
missing_CVP = TabularCPD(
    variable="CVP*",
    variable_card=2,
    values=[[0.8, 0.2, 0.7],
            [0.2, 0.8, 0.3]], # Missing probabilities: LOW = 0.2, NORMAL = 0.8, HIGH = 0.3
    evidence=["LVEDVOLUME"],
    evidence_card=[3],
    state_names={
        "CVP*": [0, 1],
        "LVEDVOLUME": ["LOW", "NORMAL", "HIGH"]}
)

samples = alarm.simulate(n_samples=1000, missing_prob=[missing_CVP], return_full=True)
samples.head()
[14]:
LVEDVOLUME MINVOLSET TPR CVP_full HYPOVOLEMIA INTUBATION HR DISCONNECT ARTCO2 ANAPHYLAXIS ... HRBP HRSAT PVSAT PAP FIO2 CO ERRLOWOUTPUT PULMEMBOLUS HISTORY HREKG
0 NORMAL NORMAL HIGH NORMAL FALSE NORMAL NORMAL FALSE HIGH FALSE ... LOW LOW LOW NORMAL LOW NORMAL FALSE FALSE FALSE LOW
1 NORMAL NORMAL LOW NORMAL FALSE NORMAL HIGH FALSE HIGH FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
2 NORMAL NORMAL NORMAL NORMAL FALSE NORMAL HIGH FALSE HIGH FALSE ... NORMAL HIGH LOW NORMAL NORMAL LOW TRUE FALSE FALSE HIGH
3 NORMAL NORMAL HIGH NORMAL FALSE ONESIDED HIGH FALSE HIGH FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
4 NORMAL NORMAL HIGH NORMAL FALSE NORMAL HIGH FALSE HIGH FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH

5 rows × 38 columns

[15]:
print(f"Missing values: {samples['CVP'].isna().sum()}/{len(samples['CVP'])}")
print()

print("Original Distribution:")
print(get_distribution(samples, variables=["LVEDVOLUME", "CVP_full"]))
print()
print("Distribution of Missing/Removed")
# Missingness is most likely when LVEDVOLUME is NORMAL, so those rows should be
# over-represented among the removed values and the other rows under-represented.
print(get_distribution(samples.loc[samples["CVP"].isna()], variables=["LVEDVOLUME", "CVP_full"]))
Missing values: 656/1000

Original Distribution:
LVEDVOLUME  CVP_full
HIGH        HIGH        0.140
            LOW         0.000
            NORMAL      0.065
LOW         HIGH        0.000
            LOW         0.095
            NORMAL      0.003
NORMAL      HIGH        0.009
            LOW         0.026
            NORMAL      0.662
dtype: float64

Distribution of Missing/Removed
LVEDVOLUME  CVP_full
HIGH        HIGH        0.083841
            LOW         0.000000
            NORMAL      0.027439
LOW         HIGH        0.000000
            LOW         0.022866
            NORMAL      0.001524
NORMAL      HIGH        0.009146
            LOW         0.032012
            NORMAL      0.823171
dtype: float64
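The expected effect can be computed with Bayes' rule: the fraction of removed rows is the sum over LVEDVOLUME states of P(LVEDVOLUME) × P(missing | LVEDVOLUME), and the distribution of LVEDVOLUME among removed rows is each term divided by that sum. The calculation below takes the marginal of LVEDVOLUME by summing the "Original Distribution" output above; it is a back-of-the-envelope check, not pgmpy functionality.

```python
# Marginal of LVEDVOLUME, summed over CVP_full in the "Original Distribution" output.
p_lved = {
    "HIGH": 0.140 + 0.000 + 0.065,
    "LOW": 0.000 + 0.095 + 0.003,
    "NORMAL": 0.009 + 0.026 + 0.662,
}
# Missing probabilities from the CPD of CVP* (indexed by LVEDVOLUME).
p_miss = {"LOW": 0.2, "NORMAL": 0.8, "HIGH": 0.3}

# Expected fraction of rows with CVP missing.
p_missing_total = sum(p_lved[s] * p_miss[s] for s in p_lved)
# Expected distribution of LVEDVOLUME among the removed rows (Bayes' rule).
share_among_missing = {s: p_lved[s] * p_miss[s] / p_missing_total for s in p_lved}

print(round(p_missing_total, 3))                                  # 0.639
print({s: round(v, 3) for s, v in share_among_missing.items()})  # {'HIGH': 0.096, 'LOW': 0.031, 'NORMAL': 0.873}
```

The expected missing fraction of about 0.64 and NORMAL share of about 0.87 are in line with the simulated 656/1000 and 0.864 (summing the NORMAL rows of the removed distribution).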

7.3. Missing not at random (MNAR)

[16]:
# CVP data missing depending on the unobserved original CVP value
missing_CVP = TabularCPD(
    variable="CVP*",
    variable_card=2,
    values=[[0.2, 0.4, 0.6],
            [0.8, 0.6, 0.4]], # Missing probabilities: LOW = 0.8, NORMAL = 0.6, HIGH = 0.4
    evidence=["CVP"],
    evidence_card=[3],
    state_names={
        "CVP*": [0, 1],
        "CVP": ["LOW", "NORMAL", "HIGH"]}
)

samples = alarm.simulate(n_samples=1000, missing_prob=[missing_CVP], return_full=True)
samples.head()
[16]:
LVEDVOLUME MINVOLSET TPR CVP_full HYPOVOLEMIA INTUBATION HR DISCONNECT ARTCO2 ANAPHYLAXIS ... HRBP HRSAT PVSAT PAP FIO2 CO ERRLOWOUTPUT PULMEMBOLUS HISTORY HREKG
0 NORMAL NORMAL NORMAL NORMAL FALSE NORMAL HIGH FALSE HIGH FALSE ... HIGH HIGH LOW NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
1 NORMAL NORMAL LOW NORMAL FALSE NORMAL HIGH TRUE LOW FALSE ... HIGH HIGH HIGH NORMAL NORMAL HIGH FALSE FALSE FALSE HIGH
2 NORMAL NORMAL HIGH NORMAL FALSE NORMAL HIGH FALSE HIGH FALSE ... HIGH HIGH LOW NORMAL NORMAL LOW FALSE FALSE FALSE HIGH
3 HIGH NORMAL NORMAL NORMAL TRUE NORMAL HIGH TRUE HIGH FALSE ... HIGH HIGH LOW NORMAL NORMAL NORMAL FALSE FALSE FALSE HIGH
4 HIGH NORMAL HIGH HIGH TRUE NORMAL NORMAL FALSE HIGH FALSE ... LOW LOW LOW NORMAL NORMAL LOW FALSE FALSE FALSE LOW

5 rows × 38 columns

[17]:
print(f"Missing values: {samples['CVP'].isna().sum()}/{len(samples['CVP'])}")
print()

print("Original Distribution:")
print(get_distribution(samples, variables="CVP_full"))
print()
print("Distribution of Missing/Removed")
# CVP goes missing most often when it is LOW and least often when it is HIGH, so LOW
# should be over-represented among the removed values and HIGH under-represented.
print(get_distribution(samples.loc[samples["CVP"].isna()], variables="CVP_full"))
Missing values: 605/1000

Original Distribution:
CVP_full
HIGH      0.166
LOW       0.130
NORMAL    0.704
dtype: float64

Distribution of Missing/Removed
CVP_full
HIGH      0.112397
LOW       0.173554
NORMAL    0.714050
dtype: float64
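Because MNAR missingness depends on CVP itself, analysing only the rows where CVP is observed gives a biased estimate of its distribution. The arithmetic below, using the original distribution and the missing probabilities from this example, computes the distribution that a naive complete-case analysis would expect to see; it is an illustrative calculation, not pgmpy functionality.

```python
# Original CVP distribution from the output above, and the MNAR missing probabilities.
p_cvp = {"HIGH": 0.166, "LOW": 0.130, "NORMAL": 0.704}
p_miss = {"LOW": 0.8, "NORMAL": 0.6, "HIGH": 0.4}

# Probability that a row keeps its CVP value, split by the true CVP state.
kept = {s: p_cvp[s] * (1 - p_miss[s]) for s in p_cvp}
total_kept = sum(kept.values())

# Distribution a naive analysis would estimate from the non-missing rows only.
complete_case = {s: round(v / total_kept, 3) for s, v in kept.items()}
print(round(1 - total_kept, 3))  # expected missing fraction: 0.593
print(complete_case)
```

Among the observed rows, LOW drops from 0.130 to about 0.064 while HIGH rises from 0.166 to about 0.245, and the expected missing fraction of about 0.593 is close to the simulated 605/1000.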