Parameter Learning in Discrete Bayesian Networks
In this notebook, we demonstrate how to learn the parameters (CPDs) of a discrete Bayesian network given data and the model structure. pgmpy provides three main algorithms for learning model parameters:
- Maximum Likelihood Estimator (pgmpy.estimators.MaximumLikelihoodEstimator): computes the maximum likelihood estimates of the parameters.
- Bayesian Estimator (pgmpy.estimators.BayesianEstimator): allows users to specify priors.
- Expectation Maximization (pgmpy.estimators.ExpectationMaximization): enables learning model parameters when latent variables are present in the model.
Each of the parameter estimation classes has the following two methods:
- estimate_cpd: estimates the CPD of the specified variable.
- get_parameters: estimates the CPDs of all the variables in the model.
Step 0: Generate some simulated data and a model structure
To do parameter estimation, we require two things:
1. Data: for the examples, we simulate some data from the alarm model (https://www.bnlearn.com/bnrepository/discrete-medium.html#alarm) and use it to learn back the model parameters.
2. Model structure: we also need to specify the model structure to fit the data to. In this example, we simply use the structure of the alarm model.
[1]:
from pgmpy.utils import get_example_model
from pgmpy.models import BayesianNetwork
# Load the alarm model and simulate data from it.
alarm_model = get_example_model(model="alarm")
samples = alarm_model.simulate(n_samples=int(1e3))
print(samples.head())
# Define a new model with the same structure as the alarm model.
new_model = BayesianNetwork(ebunch=alarm_model.edges())
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -1.4901161193847656e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -9.313225746154785e-09. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1175870895385742e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -9.313225746154785e-09. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -1.6763806343078613e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 2.2351741790771484e-08. Adjusting values.
WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: 1.1102230246251565e-16. Adjusting values.
EXPCO2 INTUBATION PCWP HREKG VENTLUNG SAO2 VENTALV PULMEMBOLUS \
0 LOW NORMAL NORMAL HIGH ZERO LOW ZERO FALSE
1 LOW NORMAL NORMAL NORMAL ZERO LOW ZERO FALSE
2 LOW NORMAL HIGH HIGH LOW HIGH HIGH FALSE
3 NORMAL ONESIDED NORMAL HIGH LOW LOW LOW FALSE
4 LOW ESOPHAGEAL LOW HIGH ZERO LOW LOW FALSE
ERRLOWOUTPUT HR ... LVFAILURE KINKEDTUBE HISTORY HYPOVOLEMIA \
0 FALSE NORMAL ... FALSE FALSE FALSE FALSE
1 FALSE HIGH ... FALSE FALSE FALSE FALSE
2 TRUE HIGH ... FALSE FALSE FALSE TRUE
3 FALSE HIGH ... FALSE FALSE FALSE FALSE
4 FALSE HIGH ... FALSE FALSE FALSE FALSE
STROKEVOLUME VENTMACH VENTTUBE CVP SHUNT MINVOLSET
0 NORMAL NORMAL LOW NORMAL NORMAL NORMAL
1 NORMAL LOW ZERO NORMAL NORMAL NORMAL
2 NORMAL HIGH HIGH HIGH NORMAL NORMAL
3 NORMAL NORMAL LOW NORMAL HIGH NORMAL
4 NORMAL NORMAL ZERO LOW NORMAL NORMAL
[5 rows x 37 columns]
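The simulate call above draws 1,000 joint samples from the alarm model. Conceptually, samples can be drawn from a Bayesian network without evidence by forward (ancestral) sampling: visit the nodes in topological order and sample each one from its CPD given the values already drawn for its parents. The sketch below does this with the standard library for a hypothetical two-node network Rain -> WetGrass; the network, its probabilities, and the helper forward_sample are illustrative assumptions and not pgmpy's implementation.

```python
import random

# Hypothetical two-node network Rain -> WetGrass (illustrative values only).
# Each CPD maps a tuple of parent values to a distribution over the node's states.
parents = {"Rain": (), "WetGrass": ("Rain",)}
cpds = {
    "Rain": {(): {"yes": 0.2, "no": 0.8}},
    "WetGrass": {
        ("yes",): {"yes": 0.9, "no": 0.1},
        ("no",): {"yes": 0.1, "no": 0.9},
    },
}
order = ["Rain", "WetGrass"]  # a topological order of the DAG


def forward_sample(n_samples, seed=0):
    """Draw n_samples joint assignments by ancestral (forward) sampling."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        row = {}
        for node in order:
            # Parents come earlier in the order, so their values already exist.
            dist = cpds[node][tuple(row[p] for p in parents[node])]
            states = list(dist)
            row[node] = rng.choices(states, weights=[dist[s] for s in states])[0]
        samples.append(row)
    return samples


rows = forward_sample(10_000)
rain_freq = sum(r["Rain"] == "yes" for r in rows) / len(rows)
print(round(rain_freq, 2))  # should be close to the CPD value 0.2
```

Because every node is sampled only after its parents, each row is an exact draw from the joint distribution defined by the CPDs.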
Using the Maximum Likelihood Estimator
[2]:
from pgmpy.estimators import MaximumLikelihoodEstimator
# Initialize the estimator object.
mle_est = MaximumLikelihoodEstimator(model=new_model, data=samples)
[3]:
# Estimate the CPD of the node FIO2.
print(mle_est.estimate_cpd(node="FIO2"))
+--------------+-------+
| FIO2(LOW) | 0.058 |
+--------------+-------+
| FIO2(NORMAL) | 0.942 |
+--------------+-------+
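For a node X with parents U, the maximum likelihood estimate of each CPD entry is a conditional relative frequency, P(X = x | U = u) = N(x, u) / N(u), where N counts matching rows of the data. FIO2 has no parents, so its estimate is simply the fraction of samples in each state. The minimal standard-library sketch below computes this on a hypothetical toy dataset; mle_cpd is an illustrative helper, not pgmpy's code.

```python
from collections import Counter


def mle_cpd(rows, node, parents):
    """Maximum likelihood CPD: P(node=x | parents=u) = N(x, u) / N(u).

    Returns a dict keyed by (parent_values_tuple, node_value). Combinations
    that never occur in the data are absent, i.e. get probability zero.
    """
    joint = Counter((tuple(r[p] for p in parents), r[node]) for r in rows)
    parent_counts = Counter(tuple(r[p] for p in parents) for r in rows)
    return {(u, x): n / parent_counts[u] for (u, x), n in joint.items()}


# Hypothetical toy data: each dict is one observed row.
rows = [
    {"A": "T", "B": "T"},
    {"A": "T", "B": "F"},
    {"A": "T", "B": "T"},
    {"A": "F", "B": "F"},
]
print(mle_cpd(rows, "B", ["A"]))  # P(B | A) from conditional counts
print(mle_cpd(rows, "A", []))     # a root node: plain relative frequencies
```

Note that B = T never occurs together with A = F here, so the maximum likelihood estimate of that entry is zero. With many parent configurations and limited data such zeros are common, which is one motivation for the priors used by the Bayesian Estimator.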
[4]:
# Estimate the CPD of node CVP
mle_est.estimate_cpd(node="CVP")
[4]:
<TabularCPD representing P(CVP:3 | LVEDVOLUME:3) at 0x7b043ec68860>
[5]:
# Estimate all the CPDs for `new_model`
all_cpds = mle_est.get_parameters(n_jobs=1)
# Add the estimated CPDs to the model.
new_model.add_cpds(*all_cpds)
# Check if the CPDs are added to the model
new_model.get_cpds('PCWP')
[5]:
<TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427cb85f0>
Using the Bayesian Estimator
[6]:
# Initialize the Bayesian Estimator
from pgmpy.estimators import BayesianEstimator
be_est = BayesianEstimator(model=new_model, data=samples)
The estimator methods in the BayesianEstimator class allow a few different ways to specify the priors. The prior type is chosen with the prior_type argument. Please refer to the documentation for details on the different ways these priors can be specified: https://pgmpy.org/param_estimator/bayesian_est.html#bayesian-estimator
- Dirichlet prior (prior_type="dirichlet"): requires the pseudo_counts argument, which specifies the pseudo-counts (the prior) to use for CPD estimation.
- BDeu prior (prior_type="BDeu"): requires the equivalent_sample_size argument, which is used to compute the pseudo-counts for CPD estimation.
- K2 prior (prior_type="K2"): shorthand for a Dirichlet prior with pseudo_counts=1.
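All three prior types amount to adding pseudo-counts α to the observed counts before normalizing. Assuming the usual posterior-mean estimate, for a node with r states and q parent configurations, P(x | u) = (N(x, u) + α(x, u)) / (N(u) + Σ_x α(x, u)), where K2 uses α = 1 everywhere, BDeu uses α = equivalent_sample_size / (r · q), and dirichlet uses the supplied pseudo_counts. The function below is a hypothetical standard-library illustration for one parent configuration; its name and signature are assumptions, not pgmpy's API.

```python
def posterior_cpd_column(counts, prior_type, equivalent_sample_size=None,
                         pseudo_counts=None, n_parent_configs=1):
    """Smoothed P(X | U = u) for a single parent configuration u.

    counts: observed counts N(x, u), one per state x of X.
    n_parent_configs: q, the number of parent configurations (used by BDeu).
    """
    r = len(counts)
    if prior_type == "K2":
        alphas = [1.0] * r  # one pseudo-count per cell
    elif prior_type == "BDeu":
        # The equivalent sample size is spread uniformly over all r * q cells.
        alphas = [equivalent_sample_size / (r * n_parent_configs)] * r
    elif prior_type == "dirichlet":
        alphas = [float(a) for a in pseudo_counts]  # user-supplied, per state
    else:
        raise ValueError(f"unknown prior_type: {prior_type}")
    total = sum(counts) + sum(alphas)
    return [(n + a) / total for n, a in zip(counts, alphas)]


# Hypothetical counts for a three-state node in one of two parent configurations.
counts = [7, 0, 3]
print(posterior_cpd_column(counts, "K2"))
print(posterior_cpd_column(counts, "BDeu", equivalent_sample_size=12,
                           n_parent_configs=2))  # [0.5625, 0.125, 0.3125]
```

The larger the pseudo-counts relative to the data, the more the estimate is pulled toward the prior (uniform, for K2 and BDeu), and no state ends up with probability zero.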
[7]:
print(be_est.estimate_cpd(node="FIO2", prior_type="BDeu", equivalent_sample_size=1000))
print(be_est.estimate_cpd(node="CVP", prior_type="dirichlet", pseudo_counts=100))
+--------------+-------+
| FIO2(LOW) | 0.279 |
+--------------+-------+
| FIO2(NORMAL) | 0.721 |
+--------------+-------+
+-------------+---------------------+-----+---------------------+
| LVEDVOLUME | LVEDVOLUME(HIGH) | ... | LVEDVOLUME(NORMAL) |
+-------------+---------------------+-----+---------------------+
| CVP(HIGH) | 0.48841698841698844 | ... | 0.10685483870967742 |
+-------------+---------------------+-----+---------------------+
| CVP(LOW) | 0.19884169884169883 | ... | 0.13709677419354838 |
+-------------+---------------------+-----+---------------------+
| CVP(NORMAL) | 0.3127413127413127 | ... | 0.7560483870967742 |
+-------------+---------------------+-----+---------------------+
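The FIO2 numbers above can be checked by hand. FIO2 has two states and no parents, and its maximum likelihood estimate of 0.058 from 1,000 samples implies 58 LOW rows and 942 NORMAL rows. With BDeu and equivalent_sample_size=1000, each of the 2 × 1 cells receives a pseudo-count of 1000 / 2 = 500; assuming the posterior-mean estimate (counts plus pseudo-counts, normalized):

```latex
\hat{P}(\mathrm{FIO2}=\mathrm{LOW}) = \frac{58 + 500}{1000 + 2 \cdot 500} = \frac{558}{2000} = 0.279,
\qquad
\hat{P}(\mathrm{FIO2}=\mathrm{NORMAL}) = \frac{942 + 500}{1000 + 2 \cdot 500} = \frac{1442}{2000} = 0.721
```

Here the pseudo-counts total as much as the data itself, which is why the estimate moves so far toward the uniform distribution.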
[8]:
# K2 sets every pseudo-count to 1, so equivalent_sample_size is not used here.
be_est.get_parameters(prior_type="K2")
[8]:
[<TabularCPD representing P(HYPOVOLEMIA:2) at 0x7b050bad4590>,
<TabularCPD representing P(LVEDVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cb9c40>,
<TabularCPD representing P(STROKEVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cb9250>,
<TabularCPD representing P(CVP:3 | LVEDVOLUME:3) at 0x7b0427cb9070>,
<TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427cbb080>,
<TabularCPD representing P(LVFAILURE:2) at 0x7b0427cbb110>,
<TabularCPD representing P(HISTORY:2 | LVFAILURE:2) at 0x7b0427cbaf90>,
<TabularCPD representing P(CO:3 | HR:3, STROKEVOLUME:3) at 0x7b0427cbaa20>,
<TabularCPD representing P(ERRLOWOUTPUT:2) at 0x7b0427cbb020>,
<TabularCPD representing P(HRBP:3 | ERRLOWOUTPUT:2, HR:3) at 0x7b0427cbb2c0>,
<TabularCPD representing P(ERRCAUTER:2) at 0x7b0427cbb4d0>,
<TabularCPD representing P(HREKG:3 | ERRCAUTER:2, HR:3) at 0x7b0427cbb1a0>,
<TabularCPD representing P(HRSAT:3 | ERRCAUTER:2, HR:3) at 0x7b0427cbac00>,
<TabularCPD representing P(INSUFFANESTH:2) at 0x7b0427cba3c0>,
<TabularCPD representing P(CATECHOL:2 | ARTCO2:3, INSUFFANESTH:2, SAO2:3, TPR:3) at 0x7b0427cbb170>,
<TabularCPD representing P(ANAPHYLAXIS:2) at 0x7b0427cbaea0>,
<TabularCPD representing P(TPR:3 | ANAPHYLAXIS:2) at 0x7b0427cbb200>,
<TabularCPD representing P(BP:3 | CO:3, TPR:3) at 0x7b0427cbb0b0>,
<TabularCPD representing P(KINKEDTUBE:2) at 0x7b0427cba1b0>,
<TabularCPD representing P(PRESS:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cb98e0>,
<TabularCPD representing P(VENTLUNG:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cbb350>,
<TabularCPD representing P(FIO2:2) at 0x7b0427cbade0>,
<TabularCPD representing P(PVSAT:3 | FIO2:2, VENTALV:4) at 0x7b0427cba900>,
<TabularCPD representing P(SAO2:3 | PVSAT:3, SHUNT:2) at 0x7b0427cb8d70>,
<TabularCPD representing P(PULMEMBOLUS:2) at 0x7b0427cbb2f0>,
<TabularCPD representing P(PAP:3 | PULMEMBOLUS:2) at 0x7b0427cb9790>,
<TabularCPD representing P(SHUNT:2 | INTUBATION:3, PULMEMBOLUS:2) at 0x7b0427cbb3e0>,
<TabularCPD representing P(INTUBATION:3) at 0x7b0427cbb4a0>,
<TabularCPD representing P(MINVOL:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427cbb530>,
<TabularCPD representing P(VENTALV:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427cbb560>,
<TabularCPD representing P(DISCONNECT:2) at 0x7b0427cbb590>,
<TabularCPD representing P(VENTTUBE:4 | DISCONNECT:2, VENTMACH:4) at 0x7b0427cbb5c0>,
<TabularCPD representing P(MINVOLSET:3) at 0x7b0427cbb5f0>,
<TabularCPD representing P(VENTMACH:4 | MINVOLSET:3) at 0x7b0427cbb620>,
<TabularCPD representing P(EXPCO2:4 | ARTCO2:3, VENTLUNG:4) at 0x7b0427cbb650>,
<TabularCPD representing P(ARTCO2:3 | VENTALV:4) at 0x7b0427cbb680>,
<TabularCPD representing P(HR:3 | CATECHOL:2) at 0x7b0427cbb6b0>]
Using Expectation Maximization
The Expectation Maximization (EM) estimator can be used when latent variables are present in the model. To simulate this scenario, we construct a model with the same structure as the alarm model, mark some of its variables (HISTORY and CVP) as latent, and drop those variables from samples to simulate missing data.
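EM alternates between filling in the unobserved variables in expectation and re-estimating the CPDs from those expected counts. The sketch below is a hypothetical standard-library illustration, not pgmpy's implementation: it runs EM on a tiny network Z -> X1, ..., Z -> Xk in which the binary variable Z never appears in the data. The E-step computes P(Z | observed row) under the current CPDs; the M-step re-estimates P(Z) and P(Xj | Z) from the resulting soft counts. Since a latent variable has no data, its number of states must be assumed; the sketch fixes Z to two states. pgmpy's ExpectationMaximization likewise assumes two states per latent variable by default (which is why CVP appears with 2 states rather than 3 in the output below), and its get_parameters method accepts a latent_card argument to override this.

```python
import math
import random


def em_latent_class(data, n_iter=100, seed=0):
    """EM for a hypothetical network Z -> X1, ..., Z -> Xk with Z latent.

    data: list of 0/1 tuples (x1, ..., xk); Z never appears in the data.
    Returns P(Z=1), theta with theta[j][z] = P(Xj=1 | Z=z), and the
    log-likelihood of the data computed at the start of each iteration.
    """
    rng = random.Random(seed)
    k = len(data[0])
    pz = 0.5
    # Random initialization breaks the symmetry between the two latent states.
    theta = [[rng.uniform(0.25, 0.75) for _ in range(2)] for _ in range(k)]
    history = []
    for _ in range(n_iter):
        # E-step: responsibility P(Z=1 | x) for every row under current CPDs.
        resp, loglik = [], 0.0
        for x in data:
            joint = []
            for z, prior in ((0, 1.0 - pz), (1, pz)):
                p = prior
                for j, xj in enumerate(x):
                    p *= theta[j][z] if xj else 1.0 - theta[j][z]
                joint.append(p)
            evidence = joint[0] + joint[1]  # P(x) with Z summed out
            loglik += math.log(evidence)
            resp.append(joint[1] / evidence)
        history.append(loglik)
        # M-step: maximum likelihood estimates from the expected (soft) counts.
        n1 = sum(resp)
        n0 = len(data) - n1
        pz = n1 / len(data)
        for j in range(k):
            theta[j][1] = sum(r for r, x in zip(resp, data) if x[j]) / n1
            theta[j][0] = sum(1.0 - r for r, x in zip(resp, data) if x[j]) / n0
    return pz, theta, history


# Usage: simulate data from a known model, hide Z, then re-estimate.
rng = random.Random(42)
data = []
for _ in range(1000):
    z = rng.random() < 0.3
    probs = (0.9, 0.8, 0.85) if z else (0.1, 0.2, 0.15)
    data.append(tuple(int(rng.random() < q) for q in probs))
pz, theta, history = em_latent_class(data)
print(round(min(pz, 1 - pz), 2))  # should be near 0.3; labels may be swapped
```

Each EM iteration never decreases the log-likelihood, but it can converge to a local optimum, and the latent states are only identified up to relabeling: nothing in the data says which state of Z should be called 0 and which 1.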
[9]:
model_latent = BayesianNetwork(alarm_model.edges(), latents={'HISTORY', 'CVP'})
samples_latent = samples.drop(['HISTORY', 'CVP'], axis=1)
[10]:
from pgmpy.estimators import ExpectationMaximization as EM
em_est = EM(model=model_latent, data=samples_latent)
em_est.get_parameters()
[10]:
[<TabularCPD representing P(EXPCO2:4 | ARTCO2:3, VENTLUNG:4) at 0x7b0427e26930>,
<TabularCPD representing P(INTUBATION:3) at 0x7b0427cbbe90>,
<TabularCPD representing P(PCWP:3 | LVEDVOLUME:3) at 0x7b0427e27590>,
<TabularCPD representing P(HREKG:3 | ERRCAUTER:2, HR:3) at 0x7b0427d01fd0>,
<TabularCPD representing P(VENTLUNG:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cb9eb0>,
<TabularCPD representing P(SAO2:3 | PVSAT:3, SHUNT:2) at 0x7b0427e27d10>,
<TabularCPD representing P(VENTALV:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427d019a0>,
<TabularCPD representing P(PULMEMBOLUS:2) at 0x7b0427e260f0>,
<TabularCPD representing P(ERRLOWOUTPUT:2) at 0x7b0427d01910>,
<TabularCPD representing P(HR:3 | CATECHOL:2) at 0x7b0427d01bb0>,
<TabularCPD representing P(HRSAT:3 | ERRCAUTER:2, HR:3) at 0x7b0427d02480>,
<TabularCPD representing P(DISCONNECT:2) at 0x7b0427ca7290>,
<TabularCPD representing P(ERRCAUTER:2) at 0x7b0427d01f40>,
<TabularCPD representing P(CO:3 | HR:3, STROKEVOLUME:3) at 0x7b0427cba150>,
<TabularCPD representing P(BP:3 | CO:3, TPR:3) at 0x7b0427d02720>,
<TabularCPD representing P(LVEDVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427d02870>,
<TabularCPD representing P(ANAPHYLAXIS:2) at 0x7b0427ca6c90>,
<TabularCPD representing P(TPR:3 | ANAPHYLAXIS:2) at 0x7b0427d025a0>,
<TabularCPD representing P(HRBP:3 | ERRLOWOUTPUT:2, HR:3) at 0x7b0427cba0c0>,
<TabularCPD representing P(PVSAT:3 | FIO2:2, VENTALV:4) at 0x7b0427d02b10>,
<TabularCPD representing P(CATECHOL:2 | ARTCO2:3, INSUFFANESTH:2, SAO2:3, TPR:3) at 0x7b0427cb96d0>,
<TabularCPD representing P(INSUFFANESTH:2) at 0x7b0427cba8d0>,
<TabularCPD representing P(FIO2:2) at 0x7b0427d02390>,
<TabularCPD representing P(ARTCO2:3 | VENTALV:4) at 0x7b0427d02ab0>,
<TabularCPD representing P(PAP:3 | PULMEMBOLUS:2) at 0x7b0427d01af0>,
<TabularCPD representing P(MINVOL:4 | INTUBATION:3, VENTLUNG:4) at 0x7b0427d02fc0>,
<TabularCPD representing P(PRESS:4 | INTUBATION:3, KINKEDTUBE:2, VENTTUBE:4) at 0x7b0427cbadb0>,
<TabularCPD representing P(LVFAILURE:2) at 0x7b044b06f4a0>,
<TabularCPD representing P(KINKEDTUBE:2) at 0x7b0427d02e10>,
<TabularCPD representing P(HYPOVOLEMIA:2) at 0x7b0427d03170>,
<TabularCPD representing P(STROKEVOLUME:3 | HYPOVOLEMIA:2, LVFAILURE:2) at 0x7b0427cbafc0>,
<TabularCPD representing P(VENTMACH:4 | MINVOLSET:3) at 0x7b0427d02f30>,
<TabularCPD representing P(VENTTUBE:4 | DISCONNECT:2, VENTMACH:4) at 0x7b0427d01ca0>,
<TabularCPD representing P(SHUNT:2 | INTUBATION:3, PULMEMBOLUS:2) at 0x7b0427ca6e10>,
<TabularCPD representing P(MINVOLSET:3) at 0x7b0427e256d0>,
<TabularCPD representing P(CVP:2 | LVEDVOLUME:3) at 0x7b043ef333b0>,
<TabularCPD representing P(HISTORY:2 | LVFAILURE:2) at 0x7b043ec68b60>]
Shortcut for learning and adding CPDs to the model
The BayesianNetwork class also provides a fit method that estimates all the CPDs and adds them to the model in a single step.
[11]:
# Shortcut for learning all the parameters and adding the CPDs to the model.
model_struct = BayesianNetwork(ebunch=alarm_model.edges())
model_struct.fit(data=samples, estimator=MaximumLikelihoodEstimator)
print(model_struct.get_cpds("FIO2"))
model_struct = BayesianNetwork(ebunch=alarm_model.edges())
model_struct.fit(
data=samples,
estimator=BayesianEstimator,
prior_type="BDeu",
equivalent_sample_size=1000,
)
print(model_struct.get_cpds("FIO2"))
+--------------+-------+
| FIO2(LOW) | 0.058 |
+--------------+-------+
| FIO2(NORMAL) | 0.942 |
+--------------+-------+
+--------------+-------+
| FIO2(LOW) | 0.279 |
+--------------+-------+
| FIO2(NORMAL) | 0.721 |
+--------------+-------+