[10]:
def get_distribution(samples, variables=None):
    """Estimate an empirical probability distribution from a table of samples.

    Counts how often each distinct combination of values of ``variables``
    occurs in ``samples`` and divides by the total number of rows.

    Examples:
        Marginal P(A):  get_distribution(samples, variables=['A'])
        Joint P(A, B):  get_distribution(samples, variables=['A', 'B'])

    Parameters
    ----------
    samples : pandas.DataFrame
        One row per sample; columns are variable names.
    variables : list of str
        Columns to compute the distribution over. Required.

    Returns
    -------
    pandas.Series
        Relative frequency of each observed value combination, indexed by
        those values (sorted, as ``groupby`` sorts keys by default).

    Raises
    ------
    ValueError
        If ``variables`` is not given.

    Note
    ----
    ``groupby`` drops rows whose grouping key is missing (NaN) by default,
    while the denominator is the full row count, so the result may sum to
    less than 1 when such rows are present.
    """
    if variables is not None:
        counts = samples.groupby(variables).size()
        n_rows = len(samples)  # same as samples.shape[0]
        return counts / n_rows
    raise ValueError("variables must be specified")
[24]:
# Load pgmpy's bundled "alarm" example model; every simulation below samples from it.
# TabularCPD is used later to build the soft evidence / soft intervention distributions.
from pgmpy.factors.discrete import TabularCPD

from pgmpy.utils import get_example_model

alarm = get_example_model("alarm")

11. Normal Bayesian Network (no time variation)

11.1. Simulation without evidence

[28]:
# Baseline: sample from the unmodified network (no evidence, no intervention).
samples = alarm.simulate(n_samples=10_000)
get_distribution(samples, ["HISTORY"])  # P(HISTORY)
[28]:
HISTORY
FALSE    0.9452
TRUE     0.0548
dtype: float64

11.2. Simulation with hard evidence

[23]:
# Hard evidence: condition the simulation on the observed states CVP=NORMAL, HR=HIGH.
samples = alarm.simulate(n_samples=int(1e4), evidence={"CVP": "NORMAL", "HR": "HIGH"})
get_distribution(samples, variables=["HISTORY"])  # P(HISTORY | CVP=NORMAL, HR=HIGH)
[23]:
HISTORY
FALSE    0.9847
TRUE     0.0153
dtype: float64

11.3. Simulation with soft/virtual evidence

[32]:
# Virtual (soft) evidence on CVP: a single-column TabularCPD over CVP's three
# states, weighting NORMAL / LOW / HIGH by 0.2 / 0.3 / 0.5.
soft_evidence = [
    TabularCPD(
        "CVP", 3, [[0.2], [0.3], [0.5]], state_names={"CVP": ["NORMAL", "LOW", "HIGH"]}
    )
]
samples = alarm.simulate(n_samples=int(1e4), virtual_evidence=soft_evidence)
get_distribution(samples, variables=["HISTORY"])  # P(HISTORY) given soft evidence on CVP
[32]:
HISTORY
FALSE    0.9609
TRUE     0.0391
dtype: float64

11.4. Simulation with intervention

[33]:
# Intervention: do(CVP=NORMAL, HR=HIGH) sets these variables directly rather than
# conditioning on them as observations.
samples = alarm.simulate(n_samples=int(1e4), do={"CVP": "NORMAL", "HR": "HIGH"})
get_distribution(samples, variables=["HISTORY"])  # P(HISTORY | do(CVP=NORMAL, HR=HIGH))
[33]:
HISTORY
FALSE    0.9488
TRUE     0.0512
dtype: float64

11.5. Simulation with soft/virtual intervention

[34]:
# Virtual (soft) intervention on CVP: a single-column TabularCPD weighting CVP's
# NORMAL / LOW / HIGH states by 0.2 / 0.3 / 0.5, passed as `virtual_intervention`.
# NOTE: the list keeps the name `soft_evidence`, but here it is used as an
# intervention distribution, not as evidence.
soft_evidence = [
    TabularCPD(
        "CVP", 3, [[0.2], [0.3], [0.5]], state_names={"CVP": ["NORMAL", "LOW", "HIGH"]}
    )
]
samples = alarm.simulate(n_samples=int(1e4), virtual_intervention=soft_evidence)
get_distribution(samples, variables=["HISTORY"])  # P(HISTORY) under a soft intervention on CVP
[34]:
HISTORY
FALSE    0.9508
TRUE     0.0492
dtype: float64
[ ]: