Source code for pgmpy.factors.hybrid.FunctionalCPD
import numpy as np
import pandas as pd
from pgmpy.factors.base import BaseFactor
import pyro
[docs]
class FunctionalCPD(BaseFactor):
    """
    Defines a Functional CPD.

    A Functional CPD can represent an arbitrary conditional probability
    distribution. The distribution is described by a user-supplied callable
    `fn` whose return value is handed to `pyro.sample` to draw values for the
    variable.

    Parameters
    ----------
    variable: str
        Name of the variable for which this CPD is defined.

    fn: callable
        A callable taking a single argument holding the parent values and
        returning the distribution to sample from (for example an object from
        `pyro.distributions`). When `sample` is given a `parent_sample`
        DataFrame, `fn` is called once per row with that row (a
        `pandas.Series`, indexable by column name); otherwise it is called
        with `None`.

    parents: list[str], optional
        List of parent variable names (default is None for no parents).

    Examples
    --------
    # For P(X3| X1, X2) = N(0.2x1 + 0.3x2 + 1.0; 1), we can write

    >>> from pgmpy.factors.hybrid import FunctionalCPD
    >>> import pyro.distributions as dist
    >>> cpd = FunctionalCPD(
    ...    variable="x3",
    ...    fn=lambda parent_sample: dist.Normal(
    ...        0.2 * parent_sample["x1"] + 0.3 * parent_sample["x2"] + 1.0, 1),
    ...    parents=["x1", "x2"])
    >>> cpd.variable
    'x3'
    >>> cpd.parents
    ['x1', 'x2']
    """

    def __init__(self, variable, fn, parents=None):
        if not callable(fn):
            raise ValueError("`fn` must be a callable function.")

        self.variable = variable
        self.fn = fn

        if not parents:
            # None, an empty list, or any other falsy value means "no parents".
            self.parents = []
        elif isinstance(parents, str):
            # list("x1") would silently split the name into characters.
            raise TypeError(
                "`parents` must be a list of variable names, not a single string."
            )
        else:
            # Copy into a fresh list: accepts tuples and other iterables, and
            # keeps later mutation of the caller's list from leaking in.
            self.parents = list(parents)

        self.variables = [variable, *self.parents]

    def sample(self, n_samples=100, parent_sample=None):
        """
        Simulates values for the variable based on its CPD.

        Parameters
        ----------
        n_samples: int, (default: 100)
            The number of samples to generate.

        parent_sample: pandas.DataFrame, optional
            A DataFrame where each column represents a parent variable and
            rows are samples. When given, it must contain a column for every
            parent and have exactly `n_samples` rows; sample `i` is drawn from
            the distribution returned by `fn` for row `i`. When omitted, `fn`
            is called with `None` for every sample.

        Returns
        -------
        sampled_values: numpy.ndarray
            Array of sampled values for the variable.

        Raises
        ------
        TypeError
            If `parent_sample` is given but is not a pandas DataFrame.
        ValueError
            If `parent_sample` lacks a column for some parent, or its length
            differs from `n_samples`.

        Examples
        --------
        >>> import pandas as pd
        >>> from pgmpy.factors.hybrid import FunctionalCPD
        >>> import pyro.distributions as dist
        >>> cpd = FunctionalCPD(
        ...    variable="x3",
        ...    fn=lambda parent_sample: dist.Normal(
        ...        1.0 + 0.2 * parent_sample["x1"] + 0.3 * parent_sample["x2"], 1),
        ...    parents=["x1", "x2"])
        >>> parent_samples = pd.DataFrame({'x1' : [5, 10], 'x2' : [1, -1]})
        >>> cpd.sample(2, parent_samples)
        """
        if parent_sample is not None:
            if not isinstance(parent_sample, pd.DataFrame):
                raise TypeError("`parent_sample` must be a pandas DataFrame.")

            missing_parents = [
                p for p in self.parents if p not in parent_sample.columns
            ]
            if missing_parents:
                raise ValueError(
                    f"Missing values for parent variables: {missing_parents}"
                )

            if len(parent_sample) != n_samples:
                raise ValueError("Length of `parent_sample` must match `n_samples`.")

            # One row of parent values per requested sample.
            fn_inputs = (parent_sample.iloc[i, :] for i in range(n_samples))
        else:
            # No parent values supplied: `fn` receives None for every draw.
            fn_inputs = (None for _ in range(n_samples))

        site_name = f"{self.variable}"
        # `.item()` converts each drawn value to a plain Python scalar.
        sampled_values = [
            pyro.sample(site_name, self.fn(fn_input)).item()
            for fn_input in fn_inputs
        ]
        return np.array(sampled_values)

    def __str__(self):
        # Not every callable has `__name__` (e.g. functools.partial objects or
        # instances defining `__call__`); fall back to the type name for those.
        raw_name = getattr(self.fn, "__name__", None)
        if raw_name is None:
            fn_name = type(self.fn).__name__
        elif raw_name == "<lambda>":
            fn_name = "lambda fun."
        else:
            fn_name = raw_name

        if self.parents:
            return f"P({self.variable} | {', '.join(self.parents)}) = {fn_name}"
        return f"P({self.variable}) = {fn_name}"

    def __repr__(self):
        return f"<FunctionalCPD: {str(self)}> at {hex(id(self))}"