Simulating Data From Bayesian Networks

pgmpy implements the DiscreteBayesianNetwork.simulate method to allow users to simulate data from a fully defined Bayesian Network under various conditions. These conditions can be any combination of:

  1. Virtual Evidence

  2. Hard Evidence

  3. Virtual Intervention

  4. Hard Intervention

Users can also pass data for a subset of the model's variables (partial samples) to the simulation method, which fixes those variables to the provided values in the generated samples.

Lastly, the user can also generate data with missing values, according to a user-defined CPD, to simulate realistic real-world data and evaluate how missingness affects inference.

[1]:
# A helper function to compute probability distributions from simulated samples.
def get_distribution(samples, variables=None):
    """
    For marginal distribution, P(A): get_distribution(samples, variables=['A'])
    For joint distribution, P(A, B): get_distribution(samples, variables=['A', 'B'])
    """
    if variables is None:
        raise ValueError("variables must be specified")

    return samples.groupby(variables, observed=False).size() / samples.shape[0]
[2]:
# Do not print warnings
import logging
from pgmpy.global_vars import logger
logger.setLevel(logging.ERROR)

# Specify the model to simulate data from.
from pgmpy.factors.discrete import TabularCPD
from pgmpy.example_models import load_model

import numpy as np
import pandas as pd

alarm = load_model("bnlearn/alarm")

1. Standard simulation

Without any specified conditions for simulation, the simulate method draws samples from the joint distribution of the model.

[3]:
samples = alarm.simulate(n_samples=int(1e4))
samples.head()
[3]:
VENTMACH HYPOVOLEMIA CVP STROKEVOLUME PRESS SAO2 PCWP HR MINVOLSET BP ... ARTCO2 SHUNT CATECHOL ANAPHYLAXIS CO PAP EXPCO2 INTUBATION FIO2 HRBP
0 NORMAL FALSE NORMAL NORMAL NORMAL LOW NORMAL HIGH NORMAL LOW ... HIGH NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL HIGH
1 NORMAL FALSE NORMAL NORMAL HIGH LOW NORMAL HIGH NORMAL LOW ... HIGH NORMAL HIGH FALSE HIGH HIGH LOW NORMAL NORMAL HIGH
2 NORMAL FALSE NORMAL NORMAL HIGH LOW NORMAL HIGH NORMAL LOW ... HIGH NORMAL HIGH FALSE HIGH NORMAL NORMAL ESOPHAGEAL NORMAL HIGH
3 NORMAL FALSE NORMAL HIGH LOW HIGH NORMAL LOW NORMAL NORMAL ... LOW NORMAL NORMAL FALSE NORMAL NORMAL LOW NORMAL NORMAL LOW
4 NORMAL FALSE NORMAL NORMAL HIGH NORMAL NORMAL HIGH NORMAL LOW ... HIGH NORMAL HIGH FALSE HIGH LOW LOW NORMAL NORMAL HIGH

5 rows × 37 columns
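
To make "drawing from the joint distribution" concrete, the sketch below forward-samples a hypothetical two-node network Rain -> WetGrass in plain Python: each variable is sampled in topological order from its CPD, given the values already sampled for its parents. The network, the probabilities, and the function names are illustrative assumptions and are not part of pgmpy or the alarm model; this shows the general idea and is not a description of pgmpy's internals.

```python
import random

# Hypothetical CPDs for a two-node network: Rain -> WetGrass.
P_RAIN = {"yes": 0.3, "no": 0.7}
P_WET_GIVEN_RAIN = {
    "yes": {"wet": 0.9, "dry": 0.1},
    "no": {"wet": 0.2, "dry": 0.8},
}


def draw(dist, rng):
    """Draw one state from a {state: probability} mapping."""
    states = list(dist)
    weights = [dist[s] for s in states]
    return rng.choices(states, weights=weights, k=1)[0]


def forward_sample(n_samples, seed=0):
    """Sample parents first, then children conditioned on the sampled parents."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n_samples):
        rain = draw(P_RAIN, rng)
        wet = draw(P_WET_GIVEN_RAIN[rain], rng)
        rows.append({"Rain": rain, "WetGrass": wet})
    return rows


rows = forward_sample(20_000)
frac_wet = sum(r["WetGrass"] == "wet" for r in rows) / len(rows)
# Exact marginal for comparison: 0.3 * 0.9 + 0.7 * 0.2 = 0.41
print(f"Estimated P(WetGrass=wet) = {frac_wet:.3f}")
```

The estimated marginal should approach the exact value as the number of samples grows.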

2. Simulation under specified evidence

Specifying hard evidence for some variables fixes their values to the specified evidence value during simulation.

[4]:
evidence = {"CVP": "NORMAL", "HR": "HIGH"}
samples = alarm.simulate(n_samples=int(1e4), evidence=evidence)
samples.head()
[4]:
VENTMACH HYPOVOLEMIA CVP STROKEVOLUME PRESS SAO2 PCWP HR MINVOLSET BP ... ARTCO2 SHUNT CATECHOL ANAPHYLAXIS CO PAP EXPCO2 INTUBATION FIO2 HRBP
0 NORMAL FALSE NORMAL HIGH LOW LOW NORMAL HIGH NORMAL NORMAL ... HIGH NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL HIGH
1 NORMAL FALSE NORMAL NORMAL ZERO LOW NORMAL HIGH NORMAL NORMAL ... HIGH NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL HIGH
2 NORMAL FALSE NORMAL HIGH HIGH LOW NORMAL HIGH NORMAL HIGH ... HIGH NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL HIGH
3 NORMAL FALSE NORMAL NORMAL LOW LOW NORMAL HIGH NORMAL HIGH ... HIGH NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL HIGH
4 NORMAL FALSE NORMAL NORMAL HIGH LOW NORMAL HIGH NORMAL HIGH ... HIGH NORMAL HIGH FALSE HIGH LOW LOW NORMAL NORMAL HIGH

5 rows × 37 columns

[5]:
# All values of HR and CVP should be set to HIGH and NORMAL respectively.
print(all(samples.HR == "HIGH"))
print(all(samples.CVP == "NORMAL"))
True
True

3. Simulation under soft/virtual evidence

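Unlike hard evidence, which fixes each specified variable to a single value, virtual evidence expresses uncertain evidence as a distribution over a variable's states. This soft evidence is combined with the model's own distribution rather than overriding it, so the resulting marginal generally differs from the specified values.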

[6]:
# The virtual evidence is specified using TabularCPDs. Here, the values for CVP are LOW = 0.2, NORMAL = 0.3, and HIGH = 0.5 (matching the order in state_names).
cvp_evidence = TabularCPD(variable="CVP",
                          variable_card=3,
                          values=[[0.2], [0.3], [0.5]],
                          state_names={"CVP": ["LOW", "NORMAL", "HIGH"]})
samples = alarm.simulate(n_samples=int(1e4), virtual_evidence=[cvp_evidence])
[7]:
# Check the marginal distribution of CVP
get_distribution(samples, variables=['CVP'])
[7]:
CVP
HIGH      0.2375
LOW       0.0710
NORMAL    0.6915
dtype: float64
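
The marginal above does not match the values passed in `cvp_evidence`, which is consistent with soft evidence being combined with the model's own beliefs instead of replacing them. The sketch below shows that arithmetic for a single variable, assuming the soft evidence acts as a likelihood that re-weights the prior (Pearl's virtual-evidence method). The prior used here is a made-up illustration, not the alarm model's actual CVP distribution, and `apply_soft_evidence` is a hypothetical helper rather than a pgmpy function.

```python
def apply_soft_evidence(prior, likelihood):
    """Return the normalized product prior(s) * likelihood(s) over all states."""
    unnormalized = {s: prior[s] * likelihood[s] for s in prior}
    total = sum(unnormalized.values())
    return {s: p / total for s, p in unnormalized.items()}


# Hypothetical prior over a three-state variable (not taken from the alarm model).
prior = {"LOW": 0.1, "NORMAL": 0.7, "HIGH": 0.2}
# Soft evidence with the same values as the virtual evidence CPD above.
likelihood = {"LOW": 0.2, "NORMAL": 0.3, "HIGH": 0.5}

posterior = apply_soft_evidence(prior, likelihood)
for state, prob in posterior.items():
    print(f"{state}: {prob:.3f}")
```

Even though the evidence favors HIGH, NORMAL remains the most probable state in this toy case because the prior weights it heavily.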

4. Simulation under specified intervention

Using the do argument, users can specify interventions on the model. The values of the intervened variables are set to the specified values, and all incoming edges to these variables are removed from the model.

[8]:
samples = alarm.simulate(n_samples=int(1e4), do={"CVP": "NORMAL", "HR": "HIGH"})
samples.head()
[8]:
VENTMACH HYPOVOLEMIA CVP STROKEVOLUME PRESS SAO2 PCWP HR MINVOLSET BP ... ARTCO2 SHUNT CATECHOL ANAPHYLAXIS CO PAP EXPCO2 INTUBATION FIO2 HRBP
0 NORMAL FALSE NORMAL NORMAL HIGH LOW NORMAL HIGH NORMAL HIGH ... HIGH NORMAL HIGH FALSE HIGH HIGH HIGH ESOPHAGEAL NORMAL HIGH
1 NORMAL FALSE NORMAL LOW HIGH LOW LOW HIGH NORMAL NORMAL ... HIGH NORMAL HIGH FALSE NORMAL NORMAL LOW NORMAL NORMAL HIGH
2 HIGH TRUE NORMAL LOW LOW HIGH LOW HIGH HIGH HIGH ... LOW NORMAL NORMAL FALSE NORMAL NORMAL LOW NORMAL NORMAL HIGH
3 NORMAL TRUE NORMAL NORMAL LOW LOW HIGH HIGH NORMAL NORMAL ... HIGH NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL HIGH
4 NORMAL FALSE NORMAL LOW NORMAL LOW LOW HIGH NORMAL LOW ... HIGH NORMAL HIGH FALSE LOW NORMAL LOW NORMAL NORMAL LOW

5 rows × 37 columns
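
The practical difference between `evidence` and `do` shows up in the ancestors of the targeted variable: conditioning on a child changes beliefs about its parents, whereas intervening on it cuts the incoming edge and leaves the parents at their prior. The sketch below computes both cases exactly for a hypothetical two-node network Rain -> WetGrass. All names and numbers are illustrative assumptions and are not part of pgmpy or the alarm model.

```python
# Hypothetical network: Rain -> WetGrass.
P_RAIN = {"yes": 0.3, "no": 0.7}
P_WET_GIVEN_RAIN = {"yes": 0.9, "no": 0.2}  # P(WetGrass = wet | Rain)


def rain_given_wet_observed():
    """P(Rain | WetGrass = wet): conditioning, computed with Bayes' rule."""
    joint = {r: P_RAIN[r] * P_WET_GIVEN_RAIN[r] for r in P_RAIN}
    total = sum(joint.values())
    return {r: p / total for r, p in joint.items()}


def rain_given_do_wet():
    """P(Rain | do(WetGrass = wet)): the edge Rain -> WetGrass is removed,
    so Rain keeps its prior distribution."""
    return dict(P_RAIN)


observed = rain_given_wet_observed()
intervened = rain_given_do_wet()
print(f"P(Rain=yes | WetGrass=wet)     = {observed['yes']:.3f}")
print(f"P(Rain=yes | do(WetGrass=wet)) = {intervened['yes']:.3f}")
```

Observing wet grass raises the probability of rain, while forcing the grass to be wet carries no information about rain.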

5. Simulation under soft/virtual intervention

Similar to virtual evidence, users can specify virtual intervention as well.

[9]:
cvp_intervention = TabularCPD(variable="CVP",
                              variable_card=3,
                              values=[[0.2], [0.3], [0.5]],
                              state_names={"CVP": ["LOW", "NORMAL", "HIGH"]})
samples = alarm.simulate(n_samples=int(1e4), virtual_intervention=[cvp_intervention])
get_distribution(samples, variables=["CVP"])  # P(CVP)
[9]:
CVP
HIGH      0.3805
LOW       0.2079
NORMAL    0.4116
dtype: float64

6. Partial samples

Users can also pass already generated data for some variables (for example, from some other simulation) to the simulation. This is equivalent to separately specifying evidence for each sample that is generated.

[10]:
# Generate some data on CVP.
partial_cvp = pd.DataFrame(np.random.choice(["LOW", "NORMAL", "HIGH"], int(1e4)), columns=['CVP'])
samples = alarm.simulate(n_samples=int(1e4), partial_samples=partial_cvp)
[11]:
print(all(samples["CVP"] == partial_cvp["CVP"]))
True

7. Simulate missing data

Lastly, users can generate data with missing values for specified variables according to a user-defined CPD. The CPD's variable must be named after the target variable followed by a * (for example, CVP* for CVP) and must have two states: 0 (Not Missing) and 1 (Missing). Optionally, the return_full argument also returns the removed values, in an additional column such as CVP_full, for comparison.

7.1. Missing completely at random (MCAR)

[12]:
# CVP data missing completely randomly with 0.4 probability
missing_CVP = TabularCPD(
    variable="CVP*",
    variable_card=2,
    values=[[0.6],
            [0.4]], # Missing probability = 0.4
    state_names={"CVP*": [0, 1]}
)

samples = alarm.simulate(n_samples=1000, missing_prob=[missing_CVP], return_full=True)
samples.head()
[12]:
VENTMACH HYPOVOLEMIA CVP STROKEVOLUME PRESS SAO2 PCWP HR MINVOLSET BP ... SHUNT CATECHOL ANAPHYLAXIS CO PAP EXPCO2 INTUBATION FIO2 CVP_full HRBP
0 NORMAL FALSE NORMAL NORMAL HIGH HIGH NORMAL NORMAL NORMAL NORMAL ... NORMAL NORMAL FALSE NORMAL NORMAL LOW NORMAL NORMAL NORMAL LOW
1 NORMAL FALSE NORMAL NORMAL LOW LOW NORMAL HIGH NORMAL NORMAL ... NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL NORMAL HIGH
2 NORMAL FALSE LOW LOW LOW LOW LOW HIGH NORMAL NORMAL ... NORMAL HIGH FALSE LOW NORMAL LOW NORMAL NORMAL LOW HIGH
3 NORMAL FALSE NaN NORMAL NORMAL LOW NORMAL HIGH NORMAL LOW ... HIGH HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL NORMAL HIGH
4 NORMAL FALSE NORMAL NORMAL LOW LOW NORMAL HIGH NORMAL HIGH ... NORMAL HIGH FALSE HIGH NORMAL ZERO NORMAL NORMAL NORMAL HIGH

5 rows × 38 columns

[13]:
print(f"Missing values: {samples['CVP'].isna().sum()}/{len(samples['CVP'])}")
print()

print("Original Distribution:")
print(get_distribution(samples, variables="CVP_full"))
print()
print("Distribution of Missing/Removed")
print(get_distribution(samples.loc[samples["CVP"].isna()], variables="CVP_full")) # Since removal was completely random, we expect minimal change in distribution
Missing values: 413/1000

Original Distribution:
CVP_full
HIGH      0.175
LOW       0.125
NORMAL    0.700
dtype: float64

Distribution of Missing/Removed
CVP_full
HIGH      0.159806
LOW       0.130751
NORMAL    0.709443
dtype: float64

7.2. Missing at random (MAR)

[14]:
# CVP data missing depending on the observed LVEDVOLUME
missing_CVP = TabularCPD(
    variable="CVP*",
    variable_card=2,
    values=[[0.8, 0.2, 0.7],
            [0.2, 0.8, 0.3]], # Missing probabilities: LOW = 0.2, NORMAL = 0.8, HIGH = 0.3
    evidence=["LVEDVOLUME"],
    evidence_card=[3],
    state_names={
        "CVP*": [0, 1],
        "LVEDVOLUME": ["LOW", "NORMAL", "HIGH"]}
)

samples = alarm.simulate(n_samples=1000, missing_prob=[missing_CVP], return_full=True)
samples.head()
[14]:
VENTMACH HYPOVOLEMIA CVP STROKEVOLUME PRESS SAO2 PCWP HR MINVOLSET BP ... SHUNT CATECHOL ANAPHYLAXIS CO PAP EXPCO2 INTUBATION FIO2 CVP_full HRBP
0 NORMAL FALSE NaN NORMAL NORMAL LOW LOW HIGH NORMAL HIGH ... NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL LOW HIGH
1 NORMAL FALSE NaN NORMAL HIGH HIGH NORMAL NORMAL NORMAL HIGH ... NORMAL NORMAL FALSE NORMAL NORMAL LOW NORMAL NORMAL NORMAL LOW
2 NORMAL TRUE HIGH NORMAL HIGH LOW HIGH HIGH NORMAL NORMAL ... HIGH HIGH FALSE NORMAL NORMAL NORMAL ONESIDED NORMAL HIGH HIGH
3 NORMAL FALSE NaN NORMAL NORMAL LOW NORMAL HIGH NORMAL HIGH ... NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL LOW HIGH
4 NORMAL FALSE NaN NORMAL LOW LOW NORMAL HIGH NORMAL HIGH ... NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL NORMAL HIGH

5 rows × 38 columns

[15]:
print(f"Missing values: {samples['CVP'].isna().sum()}/{len(samples['CVP'])}")
print()

print("Original Distribution:")
print(get_distribution(samples, variables=["LVEDVOLUME", "CVP_full"]))
print()
print("Distribution of Missing/Removed")
print(get_distribution(samples.loc[samples["CVP"].isna()], variables=["LVEDVOLUME", "CVP_full"])) # Since the probability of missingness is highest when LVEDVOLUME is "NORMAL", we expect those rows to be over-represented among the removed values and the others under-represented
Missing values: 647/1000

Original Distribution:
LVEDVOLUME  CVP_full
HIGH        HIGH        0.124
            LOW         0.002
            NORMAL      0.058
LOW         HIGH        0.000
            LOW         0.092
            NORMAL      0.008
NORMAL      HIGH        0.006
            LOW         0.042
            NORMAL      0.668
dtype: float64

Distribution of Missing/Removed
LVEDVOLUME  CVP_full
HIGH        HIGH        0.057187
            LOW         0.000000
            NORMAL      0.026275
LOW         HIGH        0.000000
            LOW         0.024730
            NORMAL      0.001546
NORMAL      HIGH        0.009274
            LOW         0.043277
            NORMAL      0.837713
dtype: float64

7.3. Missing not at random (MNAR)

[16]:
# CVP data missing depending on the unobserved original CVP value
missing_CVP = TabularCPD(
    variable="CVP*",
    variable_card=2,
    values=[[0.2, 0.4, 0.6],
            [0.8, 0.6, 0.4]], # Missing probabilities: LOW = 0.8, NORMAL = 0.6, HIGH = 0.4
    evidence=["CVP"],
    evidence_card=[3],
    state_names={
        "CVP*": [0, 1],
        "CVP": ["LOW", "NORMAL", "HIGH"]}
)

samples = alarm.simulate(n_samples=1000, missing_prob=[missing_CVP], return_full=True)
samples.head()
[16]:
VENTMACH HYPOVOLEMIA CVP STROKEVOLUME PRESS SAO2 PCWP HR MINVOLSET BP ... SHUNT CATECHOL ANAPHYLAXIS CO PAP EXPCO2 INTUBATION FIO2 CVP_full HRBP
0 NORMAL FALSE NaN NORMAL LOW LOW LOW HIGH NORMAL HIGH ... NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL NORMAL HIGH
1 NORMAL FALSE HIGH NORMAL NORMAL LOW LOW HIGH NORMAL LOW ... NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL HIGH HIGH
2 NORMAL FALSE NaN NORMAL NORMAL LOW NORMAL HIGH NORMAL HIGH ... NORMAL HIGH FALSE HIGH NORMAL LOW NORMAL NORMAL NORMAL HIGH
3 NORMAL FALSE NaN NORMAL HIGH HIGH NORMAL NORMAL NORMAL HIGH ... NORMAL NORMAL FALSE HIGH NORMAL LOW NORMAL NORMAL NORMAL HIGH
4 HIGH FALSE NORMAL LOW LOW HIGH NORMAL HIGH HIGH LOW ... NORMAL HIGH FALSE LOW NORMAL LOW NORMAL NORMAL NORMAL HIGH

5 rows × 38 columns

[17]:
print(f"Missing values: {samples['CVP'].isna().sum()}/{len(samples['CVP'])}")
print()

print("Original Distribution:")
print(get_distribution(samples, variables="CVP_full"))
print()
print("Distribution of Missing/Removed")
print(get_distribution(samples.loc[samples["CVP"].isna()], variables="CVP_full")) # Since the probability of missingness is highest when CVP is "LOW" and lowest when CVP is "HIGH", we expect "LOW" to be over-represented and "HIGH" under-represented among the removed values
Missing values: 563/1000

Original Distribution:
CVP_full
HIGH      0.174
LOW       0.112
NORMAL    0.714
dtype: float64

Distribution of Missing/Removed
CVP_full
HIGH      0.117229
LOW       0.174067
NORMAL    0.708703
dtype: float64
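
The three mechanisms differ only in what the missingness probability depends on: nothing (MCAR), another observed variable (MAR), or the value being removed (MNAR). The sketch below applies each rule to a small synthetic dataset in plain Python. The column names `X` and `Z`, the probabilities, and the helpers `mask` and `frac_low_among_missing` are hypothetical and independent of pgmpy; `None` stands in for a missing value.

```python
import random

rng = random.Random(0)

# Synthetic data: Z is always observed, X is the column that may go missing.
# X and Z are generated independently to keep the comparison simple.
rows = []
for _ in range(20_000):
    z = rng.choice(["a", "b"])
    x = rng.choice(["LOW", "HIGH"])
    rows.append({"Z": z, "X": x})


def mask(rows, missing_prob):
    """Copy rows, setting X to None with probability missing_prob(row).

    The original value is kept in X_full for comparison."""
    masked = []
    for row in rows:
        hide = rng.random() < missing_prob(row)
        masked.append(
            {"Z": row["Z"], "X": None if hide else row["X"], "X_full": row["X"]}
        )
    return masked


def frac_low_among_missing(masked):
    """Share of removed X values whose true value was LOW."""
    removed = [r["X_full"] for r in masked if r["X"] is None]
    return sum(v == "LOW" for v in removed) / len(removed)


mcar = mask(rows, lambda r: 0.4)                              # depends on nothing
mar = mask(rows, lambda r: 0.8 if r["Z"] == "a" else 0.2)     # depends on observed Z
mnar = mask(rows, lambda r: 0.8 if r["X"] == "LOW" else 0.2)  # depends on X itself

for name, data in [("MCAR", mcar), ("MAR", mar), ("MNAR", mnar)]:
    print(name, round(frac_low_among_missing(data), 2))
```

Because X and Z are independent in this toy dataset, only the MNAR rule skews the removed values toward LOW (0.8 in expectation, versus 0.5 for the other two).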