Approximate Inference Using Sampling
- class pgmpy.inference.ApproxInference.ApproxInference(model)
- get_distribution(samples, variables, joint=True)
Computes the distribution of the given variables from a set of data samples.
- Parameters:
samples (pandas.DataFrame) – A dataframe of samples generated from the model.
variables (list (array-like)) – A list of variables whose distribution needs to be computed.
joint (boolean) – If joint=True, computes the joint distribution over variables. Otherwise, returns a dict mapping each variable in variables to its marginal distribution.
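The computation get_distribution performs can be illustrated as counting: the joint distribution is the normalized frequency of each combination of states across the sample rows, and each marginal is the normalized frequency of a single column. The sketch below is a hypothetical stand-in written with only the standard library, not pgmpy's implementation. It assumes the samples arrive as a list of row dicts (a pandas.DataFrame can be converted with df.to_dict("records")) and returns plain dicts rather than pgmpy factors; the name empirical_distribution is illustrative.

```python
from collections import Counter


def empirical_distribution(rows, variables, joint=True):
    """Hypothetical sketch: estimate distributions from sampled rows.

    rows: list of dicts mapping variable name -> sampled state.
    Returns {state_tuple: probability} when joint=True, otherwise
    {variable: {state: probability}} holding one marginal per variable.
    """
    n = len(rows)
    if n == 0:
        raise ValueError("need at least one sample")
    if joint:
        # Count each combination of states over the requested variables.
        counts = Counter(tuple(row[v] for v in variables) for row in rows)
        return {states: c / n for states, c in counts.items()}
    # Count each variable's states independently to get its marginal.
    marginals = {}
    for v in variables:
        counts = Counter(row[v] for row in rows)
        marginals[v] = {state: c / n for state, c in counts.items()}
    return marginals


samples = [
    {"A": "yes", "B": "low"},
    {"A": "yes", "B": "high"},
    {"A": "no", "B": "low"},
    {"A": "yes", "B": "low"},
]
print(empirical_distribution(samples, ["A", "B"]))
# {('yes', 'low'): 0.5, ('yes', 'high'): 0.25, ('no', 'low'): 0.25}
print(empirical_distribution(samples, ["A", "B"], joint=False))
# {'A': {'yes': 0.75, 'no': 0.25}, 'B': {'low': 0.75, 'high': 0.25}}
```

States that never occur in the samples are simply absent from this sketch's output.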
- query(variables, n_samples=10000, evidence=None, virtual_evidence=None, joint=True, show_progress=True)
Performs approximate, sampling-based inference on Bayesian Networks and Dynamic Bayesian Networks.
- Parameters:
variables (list) – List of variables for which the probability distribution needs to be calculated.
n_samples (int) – The number of samples to generate for computing the distributions. A larger n_samples gives more accurate estimates at the cost of more computation time.
evidence (dict (default: None)) – The observed values, as a dict of the form {var: state_name}.
virtual_evidence (list (default: None)) – A list of pgmpy.factors.discrete.TabularCPD representing the virtual/soft evidence.
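joint (boolean (default: True)) – If True, returns the joint distribution over variables. Otherwise, returns a dict mapping each variable in variables to its marginal distribution.
show_progress (boolean (default: True)) – If True, shows a progress bar while generating samples.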
- Returns:
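Probability distribution – The estimated distribution: a single joint factor when joint=True, or a dict mapping each variable to its marginal factor when joint=False.
- Return type:
pgmpy.factors.discrete.DiscreteFactor, or dict of DiscreteFactor when joint=False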
Examples
>>> from pgmpy.utils import get_example_model
>>> from pgmpy.inference import ApproxInference
>>> model = get_example_model("alarm")
>>> infer = ApproxInference(model)
>>> infer.query(variables=["HISTORY"])
<DiscreteFactor representing phi(HISTORY:2) at 0x7f92d9f5b910>
>>> infer.query(variables=["HISTORY", "CVP"], joint=True)
<DiscreteFactor representing phi(HISTORY:2, CVP:3) at 0x7f92d9f77610>
>>> infer.query(variables=["HISTORY", "CVP"], joint=False)
{'HISTORY': <DiscreteFactor representing phi(HISTORY:2) at 0x7f92dc61eb50>, 'CVP': <DiscreteFactor representing phi(CVP:3) at 0x7f92d915ec40>}
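To show how evidence enters a sampling-based query and why a larger n_samples tightens the estimate, here is a self-contained sketch of one common scheme, rejection sampling, on a hypothetical two-node network Rain -> WetGrass with made-up probabilities. It uses only the standard library and does not call pgmpy; pgmpy's own sampler, and its handling of virtual_evidence, may differ. Each node is forward-sampled from its CPD given its parents, samples that contradict the evidence are discarded, and the remaining counts are normalized. For these numbers the exact posterior is P(Rain=yes | WetGrass=yes) = 0.27 / (0.27 + 0.21) = 0.5625.

```python
import random
from collections import Counter

# Hypothetical network Rain -> WetGrass with made-up CPDs.
P_RAIN = {"yes": 0.3, "no": 0.7}
P_WET_GIVEN_RAIN = {
    "yes": {"yes": 0.9, "no": 0.1},
    "no": {"yes": 0.3, "no": 0.7},
}


def draw(dist, rng):
    """Sample one state from a {state: probability} dict."""
    states = list(dist)
    return rng.choices(states, weights=[dist[s] for s in states])[0]


def forward_sample(rng):
    """Sample every node in topological order (parents before children)."""
    rain = draw(P_RAIN, rng)
    wet = draw(P_WET_GIVEN_RAIN[rain], rng)
    return {"Rain": rain, "WetGrass": wet}


def rejection_query(variable, evidence, n_samples, seed=0):
    """Estimate P(variable | evidence) by discarding inconsistent samples."""
    rng = random.Random(seed)
    kept = Counter()
    for _ in range(n_samples):
        sample = forward_sample(rng)
        # Keep the sample only if it agrees with every observed value.
        if all(sample[var] == state for var, state in evidence.items()):
            kept[sample[variable]] += 1
    total = sum(kept.values())
    if total == 0:
        raise ValueError("no samples matched the evidence; increase n_samples")
    return {state: count / total for state, count in kept.items()}


# Exact posterior from Bayes' rule, about 0.5625.
exact = (0.3 * 0.9) / (0.3 * 0.9 + 0.7 * 0.3)
for n in (100, 10_000, 100_000):
    est = rejection_query("Rain", {"WetGrass": "yes"}, n_samples=n)
    print(n, round(est.get("yes", 0.0), 4), "exact:", round(exact, 4))
```

Rejection sampling throws away every sample that contradicts the evidence, so unlikely evidence requires many more samples to keep the same accuracy; likelihood weighting is a common alternative that avoids this waste.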