Parameter Learning in Discrete Bayesian Networks

In this notebook, we demonstrate examples of learning the parameters (CPDs) of a Discrete Bayesian Network given the data and the model structure. pgmpy has three main algorithms for learning model parameters:

  1. Maximum Likelihood Estimator (pgmpy.estimators.MaximumLikelihoodEstimator): Estimates each CPD from the relative frequencies of the node's states for every configuration of its parents in the data.

  2. Bayesian Estimator (pgmpy.estimators.BayesianEstimator): Allows users to specify priors, which are combined with the counts observed in the data.

  3. Expectation Maximization (pgmpy.estimators.ExpectationMaximization): Enables learning model parameters when latent variables are present in the model.

Each of the parameter estimation classes has the following two methods:

  1. estimate_cpd: Estimates the CPD of the specified variable.

  2. get_parameters: Estimates the CPDs of all the variables in the model.

Step 0: Generate some simulated data and a model structure

To do parameter estimation we require two things:

  1. Data: For the examples, we simulate some data from the alarm model (https://www.bnlearn.com/bnrepository/discrete-medium.html#alarm) and use it to learn back the model parameters.

  2. Model structure: We also need to specify the model structure to fit the data to. In this example, we simply use the structure of the alarm model.

[1]:
from pgmpy.utils import get_example_model
from pgmpy.models import BayesianNetwork

# Load the alarm model and simulate data from it.
alarm_model = get_example_model(model="alarm")
samples = alarm_model.simulate(n_samples=int(1e3))

print(samples.head())

# Define a new model with the same structure as the alarm model.
new_model = BayesianNetwork(ebunch=alarm_model.edges())
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -1.4901161193847656e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -9.313225746154785e-09. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -9.313225746154785e-09. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -1.6763806343078613e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.
   EXPCO2  INTUBATION    PCWP   HREKG VENTLUNG  SAO2 VENTALV PULMEMBOLUS  \
0     LOW      NORMAL  NORMAL    HIGH     ZERO   LOW    ZERO       FALSE
1     LOW      NORMAL  NORMAL  NORMAL     ZERO   LOW    ZERO       FALSE
2     LOW      NORMAL    HIGH    HIGH      LOW  HIGH    HIGH       FALSE
3  NORMAL    ONESIDED  NORMAL    HIGH      LOW   LOW     LOW       FALSE
4     LOW  ESOPHAGEAL     LOW    HIGH     ZERO   LOW     LOW       FALSE

  ERRLOWOUTPUT      HR  ... LVFAILURE KINKEDTUBE HISTORY HYPOVOLEMIA  \
0        FALSE  NORMAL  ...     FALSE      FALSE   FALSE       FALSE
1        FALSE    HIGH  ...     FALSE      FALSE   FALSE       FALSE
2         TRUE    HIGH  ...     FALSE      FALSE   FALSE        TRUE
3        FALSE    HIGH  ...     FALSE      FALSE   FALSE       FALSE
4        FALSE    HIGH  ...     FALSE      FALSE   FALSE       FALSE

  STROKEVOLUME VENTMACH VENTTUBE     CVP   SHUNT MINVOLSET
0       NORMAL   NORMAL      LOW  NORMAL  NORMAL    NORMAL
1       NORMAL      LOW     ZERO  NORMAL  NORMAL    NORMAL
2       NORMAL     HIGH     HIGH    HIGH  NORMAL    NORMAL
3       NORMAL   NORMAL      LOW  NORMAL    HIGH    NORMAL
4       NORMAL   NORMAL     ZERO     LOW  NORMAL    NORMAL

[5 rows x 37 columns]
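Conceptually, drawing samples from a Bayesian network with no evidence can be done by ancestral (forward) sampling: visit the nodes in topological order and draw each one from its CPD given the parent states already drawn. The sketch below is a minimal pure-Python illustration on a hypothetical two-node network (`Rain -> WetGrass`, with made-up probabilities). It is not pgmpy's implementation, and the names `NETWORK`, `topological_order`, `forward_sample` and `toy_samples` are illustrative only (chosen so they do not overwrite the notebook's `samples`).

```python
import random

# Hypothetical toy network (NOT the alarm model): Rain -> WetGrass.
# Each CPD maps a tuple of parent states to a distribution over the node's states.
NETWORK = {
    "Rain": {"parents": [], "cpd": {(): {"yes": 0.2, "no": 0.8}}},
    "WetGrass": {
        "parents": ["Rain"],
        "cpd": {
            ("yes",): {"wet": 0.9, "dry": 0.1},
            ("no",): {"wet": 0.1, "dry": 0.9},
        },
    },
}


def topological_order(network):
    """Return the nodes ordered so that every parent comes before its children."""
    order, visited = [], set()

    def visit(node):
        if node in visited:
            return
        for parent in network[node]["parents"]:
            visit(parent)
        visited.add(node)
        order.append(node)

    for node in network:
        visit(node)
    return order


def forward_sample(network, n_samples, seed=0):
    """Ancestral sampling: draw each node given its already-sampled parents."""
    rng = random.Random(seed)
    order = topological_order(network)
    rows = []
    for _ in range(n_samples):
        row = {}
        for node in order:
            parent_states = tuple(row[p] for p in network[node]["parents"])
            dist = network[node]["cpd"][parent_states]
            states = list(dist)
            row[node] = rng.choices(states, weights=[dist[s] for s in states])[0]
        rows.append(row)
    return rows


toy_samples = forward_sample(NETWORK, n_samples=1000)
print(toy_samples[:3])
```

Sampling parents first guarantees that every conditional distribution is looked up with fully known parent states, so each row is a draw from the joint distribution defined by the CPDs.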

Using the Maximum Likelihood Estimator

[2]:
from pgmpy.estimators import MaximumLikelihoodEstimator

# Initialize the estimator object.
mle_est = MaximumLikelihoodEstimator(model=new_model, data=samples)
[3]:
# Estimate the CPD of the node FIO2.
print(mle_est.estimate_cpd(node="FIO2"))
+--------------+-------+
| FIO2(LOW)    | 0.058 |
+--------------+-------+
| FIO2(NORMAL) | 0.942 |
+--------------+-------+
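For a node X with parents U, the maximum likelihood estimate is the relative frequency P(X = x | U = u) = N(x, u) / N(u), where N counts the matching rows; for a root node such as FIO2 this reduces to the fraction of rows in each state. A minimal pure-Python sketch of that counting, on hypothetical toy rows rather than the alarm samples (`mle_cpd` is an illustrative helper, not a pgmpy function):

```python
from collections import Counter


def mle_cpd(rows, node, parents):
    """Maximum likelihood CPD: P(node=x | parents=u) = N(x, u) / N(u).

    `rows` is a list of dicts mapping variable names to observed states.
    Returns {parent_state_tuple: {node_state: probability}}; combinations
    that never occur in the data get no entry (i.e. probability 0).
    """
    joint = Counter((tuple(r[p] for p in parents), r[node]) for r in rows)
    parent_totals = Counter(tuple(r[p] for p in parents) for r in rows)
    cpd = {}
    for (u, x), count in joint.items():
        cpd.setdefault(u, {})[x] = count / parent_totals[u]
    return cpd


# Hypothetical toy data (not the alarm samples).
toy_rows = [
    {"A": "on", "B": "high"},
    {"A": "on", "B": "high"},
    {"A": "on", "B": "low"},
    {"A": "off", "B": "low"},
]
print(mle_cpd(toy_rows, node="B", parents=["A"]))
# {('on',): {'high': 0.6666666666666666, 'low': 0.3333333333333333}, ('off',): {'low': 1.0}}
```

With the notebook's DataFrame, `samples["FIO2"].value_counts(normalize=True)` should give the same root-node frequencies as the table above. In the toy output, B='high' never occurs with A='off', so the MLE assigns it probability 0; such zero estimates are one motivation for the priors used by the Bayesian estimator.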
[4]:
# Estimate the CPD of node CVP
mle_est.estimate_cpd(node="CVP")
[4]:
<TabularCPD representing P(CVP:3 | LVEDVOLUME:3) at 0x7b043ec68860>
[5]:
# Estimate all the CPDs for `new_model`
all_cpds = mle_est.get_parameters(n_jobs=1)

# Add the estimated CPDs to the model.
new_model.add_cpds(*all_cpds)

# Check if the CPDs are added to the model
new_model.get_cpds('PCWP')
[5]:
<TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427cb85f0>

Using the Bayesian Estimator

[6]:
# Initialize the Bayesian Estimator
from pgmpy.estimators import BayesianEstimator

be_est = BayesianEstimator(model=new_model, data=samples)

The estimator methods in the BayesianEstimator class allow a few different ways to specify the priors. The prior type is chosen with the prior_type argument. Please refer to the documentation for details on the different ways these priors can be specified: https://pgmpy.org/param_estimator/bayesian_est.html#bayesian-estimator

  1. Dirichlet prior (prior_type="dirichlet"): Requires the pseudo_counts argument, which specifies the pseudo-counts added to the observed counts for the CPD estimation.

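  2. BDeu prior (prior_type="BDeu"): Requires the equivalent_sample_size argument, which is used to compute the pseudo-counts for the CPD estimation. It is spread uniformly over all cells of the CPD, so each cell receives equivalent_sample_size / (number of states of the node × number of parent configurations).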

  3. K2 (prior_type="K2"): Shorthand for a Dirichlet prior with pseudo_counts=1.
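All three prior types amount to adding pseudo-counts α to the observed counts before normalising: P(x | u) = (N(x, u) + α) / (N(u) + r·α), where r is the number of states of the node. With K2, α = 1; with BDeu, α = ESS / (r·q), where q is the number of parent configurations; with a scalar Dirichlet pseudo-count, α is that scalar. The sketch below assumes a single scalar α for every cell, as in the pseudo_counts=100 call below; `bayesian_cpd` and its arguments are illustrative and are not pgmpy's API.

```python
def bayesian_cpd(counts, node_states, parent_configs, prior_type,
                 equivalent_sample_size=None, pseudo_counts=None):
    """Posterior-mean CPD under a symmetric Dirichlet prior (one scalar alpha).

    counts maps (parent_config, node_state) -> observed count N(x, u);
    missing keys count as 0. Returns {parent_config: {node_state: prob}}.
    """
    r = len(node_states)      # number of states of the node
    q = len(parent_configs)   # number of parent configurations (1 for a root node)
    if prior_type == "K2":
        alpha = 1.0
    elif prior_type == "BDeu":
        alpha = equivalent_sample_size / (r * q)
    elif prior_type == "dirichlet":
        alpha = float(pseudo_counts)
    else:
        raise ValueError(f"unknown prior_type: {prior_type!r}")

    cpd = {}
    for u in parent_configs:
        n_u = sum(counts.get((u, x), 0) for x in node_states)
        cpd[u] = {x: (counts.get((u, x), 0) + alpha) / (n_u + r * alpha)
                  for x in node_states}
    return cpd


# Hypothetical counts for a node B (states high/low) with one parent A (on/off).
toy_counts = {(("on",), "high"): 2, (("on",), "low"): 1, (("off",), "low"): 1}
print(bayesian_cpd(toy_counts, ["high", "low"], [("on",), ("off",)], prior_type="K2"))
# {('on',): {'high': 0.6, 'low': 0.4}, ('off',): {'high': 0.3333333333333333, 'low': 0.6666666666666666}}
```

Unlike the maximum likelihood estimate, the unseen combination A='off', B='high' now gets a non-zero probability. A larger α (a larger ESS or pseudo-count) pulls every column further toward the uniform distribution.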

[7]:
print(be_est.estimate_cpd(node="FIO2", prior_type="BDeu", equivalent_sample_size=1000))
print(be_est.estimate_cpd(node="CVP", prior_type="dirichlet", pseudo_counts=100))
+--------------+-------+
| FIO2(LOW)    | 0.279 |
+--------------+-------+
| FIO2(NORMAL) | 0.721 |
+--------------+-------+
+-------------+---------------------+-----+---------------------+
| LVEDVOLUME  | LVEDVOLUME(HIGH)    | ... | LVEDVOLUME(NORMAL)  |
+-------------+---------------------+-----+---------------------+
| CVP(HIGH)   | 0.48841698841698844 | ... | 0.10685483870967742 |
+-------------+---------------------+-----+---------------------+
| CVP(LOW)    | 0.19884169884169883 | ... | 0.13709677419354838 |
+-------------+---------------------+-----+---------------------+
| CVP(NORMAL) | 0.3127413127413127  | ... | 0.7560483870967742  |
+-------------+---------------------+-----+---------------------+
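As a check on the FIO2 output above: the maximum likelihood table shows 0.058 and 0.942 on 1000 simulated samples, i.e. counts of 58 (LOW) and 942 (NORMAL). FIO2 is a root node, so it has r = 2 states and q = 1 parent configuration, and BDeu with an equivalent sample size of 1000 adds α = 1000 / (2 · 1) = 500 pseudo-counts to each state:

```latex
\hat{P}(\mathrm{FIO2}=\mathrm{LOW})
  = \frac{N_{\mathrm{LOW}} + \alpha}{N + r\,\alpha}
  = \frac{58 + 500}{1000 + 2 \cdot 500}
  = \frac{558}{2000} = 0.279,
\qquad
\hat{P}(\mathrm{FIO2}=\mathrm{NORMAL})
  = \frac{942 + 500}{2000}
  = \frac{1442}{2000} = 0.721
```

This matches the printed table. Because the equivalent sample size equals the number of samples, the prior carries as much weight as the data, which is why the estimate moves so far from the MLE toward the uniform 0.5.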
[8]:
# K2 uses a fixed pseudo-count of 1, so equivalent_sample_size is not needed.
be_est.get_parameters(prior_type="K2")
[8]:
[<TabularCPD representing P(HYPOVOLEMIA:2) at 0x7b050bad4590>,
 <TabularCPD representing P(LVEDVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cb9c40>,
 <TabularCPD representing P(STROKEVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cb9250>,
 <TabularCPD representing P(CVP:3 | LVEDVOLUME:3) at 0x7b0427cb9070>,
 <TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427cbb080>,
 <TabularCPD representing P(LVFAILURE:2) at 0x7b0427cbb110>,
 <TabularCPD representing P(HISTORY:2 | LVFAILURE:2) at 0x7b0427cbaf90>,
 <TabularCPD representing P(CO:3 | HR:3, STROKEVOLUME:3) at 0x7b0427cbaa20>,
 <TabularCPD representing P(ERRLOWOUTPUT:2) at 0x7b0427cbb020>,
 <TabularCPD representing P(HRBP:3 | ERRLOWOUTPUT:2, HR:3) at 0x7b0427cbb2c0>,
 <TabularCPD representing P(ERRCAUTER:2) at 0x7b0427cbb4d0>,
 <TabularCPD representing P(HREKG:3 | ERRCAUTER:2, HR:3) at 0x7b0427cbb1a0>,
 <TabularCPD representing P(HRSAT:3 | ERRCAUTER:2, HR:3) at 0x7b0427cbac00>,
 <TabularCPD representing P(INSUFFANESTH:2) at 0x7b0427cba3c0>,
 <TabularCPD representing P(CATECHOL:2 | ARTCO2:3, INSUFFANESTH:2, SAO2:3, TPR:3) at 0x7b0427cbb170>,
 <TabularCPD representing P(ANAPHYLAXIS:2) at 0x7b0427cbaea0>,
 <TabularCPD representing P(TPR:3 | ANAPHYLAXIS:2) at 0x7b0427cbb200>,
 <TabularCPD representing P(BP:3 | CO:3, TPR:3) at 0x7b0427cbb0b0>,
 <TabularCPD representing P(KINKEDTUBE:2) at 0x7b0427cba1b0>,
 <TabularCPD representing P(PRESS:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cb98e0>,
 <TabularCPD representing P(VENTLUNG:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cbb350>,
 <TabularCPD representing P(FIO2:2) at 0x7b0427cbade0>,
 <TabularCPD representing P(PVSAT:3 | FIO2:2, VENTALV:4) at 0x7b0427cba900>,
 <TabularCPD representing P(SAO2:3 | PVSAT:3, SHUNT:2) at 0x7b0427cb8d70>,
 <TabularCPD representing P(PULMEMBOLUS:2) at 0x7b0427cbb2f0>,
 <TabularCPD representing P(PAP:3 | PULMEMBOLUS:2) at 0x7b0427cb9790>,
 <TabularCPD representing P(SHUNT:2 | INTUBATION:3, PULMEMBOLUS:2) at 0x7b0427cbb3e0>,
 <TabularCPD representing P(INTUBATION:3) at 0x7b0427cbb4a0>,
 <TabularCPD representing P(MINVOL:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427cbb530>,
 <TabularCPD representing P(VENTALV:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427cbb560>,
 <TabularCPD representing P(DISCONNECT:2) at 0x7b0427cbb590>,
 <TabularCPD representing P(VENTTUBE:4 | DISCONNECT:2, VENTMACH:4) at 0x7b0427cbb5c0>,
 <TabularCPD representing P(MINVOLSET:3) at 0x7b0427cbb5f0>,
 <TabularCPD representing P(VENTMACH:4 | MINVOLSET:3) at 0x7b0427cbb620>,
 <TabularCPD representing P(EXPCO2:4 | ARTCO2:3, VENTLUNG:4) at 0x7b0427cbb650>,
 <TabularCPD representing P(ARTCO2:3 | VENTALV:4) at 0x7b0427cbb680>,
 <TabularCPD representing P(HR:3 | CATECHOL:2) at 0x7b0427cbb6b0>]

Using Expectation Maximization

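The Expectation Maximization (EM) estimator can learn parameters when the model contains latent variables, i.e. variables that are never observed in the data. To simulate this scenario, we build a model with the alarm structure in which HISTORY and CVP are marked as latent, and drop those two columns from samples so that they are unobserved. Since the states of a latent variable never appear in the data, EM has to assume its cardinality; by default pgmpy uses 2 states per latent variable, which is why CVP is estimated with 2 states in the output below even though it has 3 in the alarm model.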

[9]:
model_latent = BayesianNetwork(alarm_model.edges(), latents={'HISTORY', 'CVP'})
samples_latent = samples.drop(['HISTORY', 'CVP'], axis=1)
[10]:
from pgmpy.estimators import ExpectationMaximization as EM
em_est = EM(model=model_latent, data=samples_latent)
em_est.get_parameters()
[10]:
[<TabularCPD representing P(EXPCO2:4 | ARTCO2:3, VENTLUNG:4) at 0x7b0427e26930>,
 <TabularCPD representing P(INTUBATION:3) at 0x7b0427cbbe90>,
 <TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427e27590>,
 <TabularCPD representing P(HREKG:3 | ERRCAUTER:2, HR:3) at 0x7b0427d01fd0>,
 <TabularCPD representing P(VENTLUNG:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cb9eb0>,
 <TabularCPD representing P(SAO2:3 | PVSAT:3, SHUNT:2) at 0x7b0427e27d10>,
 <TabularCPD representing P(VENTALV:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427d019a0>,
 <TabularCPD representing P(PULMEMBOLUS:2) at 0x7b0427e260f0>,
 <TabularCPD representing P(ERRLOWOUTPUT:2) at 0x7b0427d01910>,
 <TabularCPD representing P(HR:3 | CATECHOL:2) at 0x7b0427d01bb0>,
 <TabularCPD representing P(HRSAT:3 | ERRCAUTER:2, HR:3) at 0x7b0427d02480>,
 <TabularCPD representing P(DISCONNECT:2) at 0x7b0427ca7290>,
 <TabularCPD representing P(ERRCAUTER:2) at 0x7b0427d01f40>,
 <TabularCPD representing P(CO:3 | HR:3, STROKEVOLUME:3) at 0x7b0427cba150>,
 <TabularCPD representing P(BP:3 | CO:3, TPR:3) at 0x7b0427d02720>,
 <TabularCPD representing P(LVEDVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427d02870>,
 <TabularCPD representing P(ANAPHYLAXIS:2) at 0x7b0427ca6c90>,
 <TabularCPD representing P(TPR:3 | ANAPHYLAXIS:2) at 0x7b0427d025a0>,
 <TabularCPD representing P(HRBP:3 | ERRLOWOUTPUT:2, HR:3) at 0x7b0427cba0c0>,
 <TabularCPD representing P(PVSAT:3 | FIO2:2, VENTALV:4) at 0x7b0427d02b10>,
 <TabularCPD representing P(CATECHOL:2 | ARTCO2:3, INSUFFANESTH:2, SAO2:3, TPR:3) at 0x7b0427cb96d0>,
 <TabularCPD representing P(INSUFFANESTH:2) at 0x7b0427cba8d0>,
 <TabularCPD representing P(FIO2:2) at 0x7b0427d02390>,
 <TabularCPD representing P(ARTCO2:3 | VENTALV:4) at 0x7b0427d02ab0>,
 <TabularCPD representing P(PAP:3 | PULMEMBOLUS:2) at 0x7b0427d01af0>,
 <TabularCPD representing P(MINVOL:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427d02fc0>,
 <TabularCPD representing P(PRESS:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cbadb0>,
 <TabularCPD representing P(LVFAILURE:2) at 0x7b044b06f4a0>,
 <TabularCPD representing P(KINKEDTUBE:2) at 0x7b0427d02e10>,
 <TabularCPD representing P(HYPOVOLEMIA:2) at 0x7b0427d03170>,
 <TabularCPD representing P(STROKEVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cbafc0>,
 <TabularCPD representing P(VENTMACH:4 | MINVOLSET:3) at 0x7b0427d02f30>,
 <TabularCPD representing P(VENTTUBE:4 | DISCONNECT:2, VENTMACH:4) at 0x7b0427d01ca0>,
 <TabularCPD representing P(SHUNT:2 | INTUBATION:3, PULMEMBOLUS:2) at 0x7b0427ca6e10>,
 <TabularCPD representing P(MINVOLSET:3) at 0x7b0427e256d0>,
 <TabularCPD representing P(CVP:2 | LVEDVOLUME:3) at 0x7b043ef333b0>,
 <TabularCPD representing P(HISTORY:2 | LVFAILURE:2) at 0x7b043ec68b60>]
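EM alternates between an E-step, which uses the current parameters to compute the posterior distribution of the latent variable for every row, and an M-step, which re-estimates the CPDs from the resulting expected (soft) counts in the same way MLE uses hard counts; no iteration can decrease the likelihood of the observed data. The sketch below runs EM in pure Python on a hypothetical network with one latent binary cause H and three observed children (a latent class model). `em_latent_class` is illustrative and is not how pgmpy implements ExpectationMaximization; `n_latent` plays the role of the latent cardinality (compare the `latent_card` argument of pgmpy's `get_parameters`). Note that H has observed children, which is what gives the data something to say about it; a latent leaf such as CVP in the alarm structure sums out of the likelihood of the observed columns.

```python
import math
import random


def em_latent_class(rows, n_latent=2, n_iter=50, seed=0):
    """EM for a toy network H -> X1, H -> X2, ... where H is latent.

    rows: list of tuples of observed states (one entry per observed child).
    Returns (prior, cpds, log_likelihoods) with prior[h] = P(H=h) and
    cpds[j][h][x] = P(X_j = x | H = h).
    """
    rng = random.Random(seed)
    n_vars = len(rows[0])
    states = [sorted({r[j] for r in rows}) for j in range(n_vars)]

    # Random (non-uniform) initialisation breaks the symmetry between latent states.
    prior = [1.0 / n_latent] * n_latent
    cpds = []
    for j in range(n_vars):
        per_h = []
        for _ in range(n_latent):
            w = [rng.random() + 0.5 for _ in states[j]]
            total = sum(w)
            per_h.append({x: wi / total for x, wi in zip(states[j], w)})
        cpds.append(per_h)

    log_likelihoods = []
    for _ in range(n_iter):
        # E-step: responsibilities P(H=h | row) under the current parameters.
        resp, ll = [], 0.0
        for r in rows:
            joint = [prior[h] * math.prod(cpds[j][h][r[j]] for j in range(n_vars))
                     for h in range(n_latent)]
            z = sum(joint)            # P(row) with H summed out
            ll += math.log(z)
            resp.append([p / z for p in joint])
        log_likelihoods.append(ll)

        # M-step: normalise the expected (soft) counts, as MLE does with hard counts.
        totals = [sum(resp_r[h] for resp_r in resp) for h in range(n_latent)]
        prior = [t / len(rows) for t in totals]
        for j in range(n_vars):
            for h in range(n_latent):
                counts = {x: 0.0 for x in states[j]}
                for r, resp_r in zip(rows, resp):
                    counts[r[j]] += resp_r[h]
                cpds[j][h] = {x: c / totals[h] for x, c in counts.items()}
    return prior, cpds, log_likelihoods


# Hypothetical data: a hidden binary cause drives three observed yes/no symptoms.
gen = random.Random(1)
toy_rows = []
for _ in range(500):
    cause = gen.random() < 0.3
    p_yes = 0.9 if cause else 0.2
    toy_rows.append(tuple("yes" if gen.random() < p_yes else "no" for _ in range(3)))

prior, cpds, lls = em_latent_class(toy_rows, n_latent=2, n_iter=50)
print("learned P(H):", [round(p, 3) for p in prior])
print("log-likelihood: first", round(lls[0], 2), "last", round(lls[-1], 2))
```

The learned latent states are only defined up to relabelling: which of H=0 and H=1 ends up as the "cause present" state depends on the initialisation.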

Shortcut for learning and adding CPDs to the model

The BayesianNetwork class also provides a fit method that acts as a shortcut to estimate the CPDs and add them to the model in a single call.

[11]:
# Shortcut for learning all the parameters and adding the CPDs to the model.

model_struct = BayesianNetwork(ebunch=alarm_model.edges())
model_struct.fit(data=samples, estimator=MaximumLikelihoodEstimator)
print(model_struct.get_cpds("FIO2"))

model_struct = BayesianNetwork(ebunch=alarm_model.edges())
model_struct.fit(
    data=samples,
    estimator=BayesianEstimator,
    prior_type="BDeu",
    equivalent_sample_size=1000,
)
print(model_struct.get_cpds("FIO2"))
+--------------+-------+
| FIO2(LOW)    | 0.058 |
+--------------+-------+
| FIO2(NORMAL) | 0.942 |
+--------------+-------+
+--------------+-------+
| FIO2(LOW)    | 0.279 |
+--------------+-------+
| FIO2(NORMAL) | 0.721 |
+--------------+-------+