from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Tuple, Union
import networkx as nx
import numpy as np
import pandas as pd
from skbase.utils.dependencies import _check_soft_dependencies
from pgmpy import config
from pgmpy.factors.hybrid import FunctionalCPD
from pgmpy.global_vars import logger
from pgmpy.models import DiscreteBayesianNetwork
from pgmpy.utils._safe_import import _safe_import
pyro = _safe_import("pyro", pkg_name="pyro-ppl")
[docs]
class FunctionalBayesianNetwork(DiscreteBayesianNetwork):
    """
    Class for representing a Functional Bayesian Network.

    Functional Bayesian Networks allow for the representation of any
    probability distribution using CPDs in functional form (FunctionalCPD).
    Functional CPDs return a pyro.distributions object, allowing for flexible
    representation of any distribution.

    Instantiating this class requires the pgmpy backend to be set to torch
    (a ValueError is raised while the backend is "numpy") and the `pyro-ppl`
    package to be installed.
    """
def __init__(
    self,
    ebunch: Optional[List[Tuple[Hashable, Hashable]]] = None,
    latents: Optional[Set[Hashable]] = None,
    lavaan_str: Optional[str] = None,
    dagitty_str: Optional[str] = None,
):
    """
    Initializes a FunctionalBayesianNetwork.

    Parameters
    ----------
    ebunch: list
        List of edges to build the Bayesian Network. Each edge should be a tuple (u, v)
        where u, v are nodes representing the edge u -> v.

    latents: set, optional
        Set of latent variables of the model; forwarded to the parent
        constructor. Defaults to an empty set.

    lavaan_str: str, optional
        Accepted for signature compatibility. It is currently NOT forwarded
        to the parent constructor and therefore has no effect.

    dagitty_str: str, optional
        Accepted for signature compatibility. It is currently NOT forwarded
        to the parent constructor and therefore has no effect.

    Raises
    ------
    ValueError
        If the pgmpy backend is set to "numpy"; this class needs the torch
        backend.

    Examples
    --------
    >>> from pgmpy.models import FunctionalBayesianNetwork
    >>> model = FunctionalBayesianNetwork([("x1", "x2"), ("x2", "x3")])
    """
    if config.get_backend() == "numpy":
        msg = (
            f"{type(self)} requires pytorch backend, currently it is "
            "set to numpy. "
            "Call pgmpy.config.set_backend('torch') to switch the backend globally."
        )
        logger.info(msg)
        raise ValueError(msg)

    _check_soft_dependencies("pyro-ppl", obj=self)

    # A fresh set per call avoids sharing one mutable default object between
    # every model instance.
    # NOTE(review): `lavaan_str` and `dagitty_str` are not passed on here;
    # confirm whether the parent constructor accepts them before forwarding.
    super(FunctionalBayesianNetwork, self).__init__(
        ebunch=ebunch,
        latents=set() if latents is None else latents,
    )
[docs]
def add_cpds(self, *cpds: FunctionalCPD) -> None:
    """
    Associates one or more FunctionalCPDs with the Bayesian Network.

    A CPD whose variable already has a CPD in the model replaces the existing
    one (a warning is logged); otherwise the CPD is appended to the model's
    list of CPDs.

    Parameters
    ----------
    cpds: instances of FunctionalCPD
        FunctionalCPDs which will be associated with the model.

    Raises
    ------
    ValueError
        If an argument is not a FunctionalCPD, or if a CPD is defined over a
        variable that is not a node of the model.

    Examples
    --------
    >>> from pgmpy.factors.hybrid import FunctionalCPD
    >>> from pgmpy.models import FunctionalBayesianNetwork
    >>> import pyro.distributions as dist
    >>> import numpy as np
    >>> model = FunctionalBayesianNetwork([("x1", "x2"), ("x2", "x3")])
    >>> cpd1 = FunctionalCPD("x1", lambda _: dist.Normal(0, 1))
    >>> cpd2 = FunctionalCPD(
    ...     "x2", lambda parent: dist.Normal(parent["x1"] + 2.0, 1), parents=["x1"]
    ... )
    >>> cpd3 = FunctionalCPD(
    ...     "x3", lambda parent: dist.Normal(parent["x2"] + 0.3, 2), parents=["x2"]
    ... )
    >>> model.add_cpds(cpd1, cpd2, cpd3)
    """
    for new_cpd in cpds:
        if not isinstance(new_cpd, FunctionalCPD):
            raise ValueError(
                "Only FunctionalCPD can be added to Functional Bayesian Network."
            )

        # Variables of the CPD that are not nodes of this model.
        unknown_vars = set(new_cpd.variables).difference(self.nodes())
        if unknown_vars:
            raise ValueError(f"CPD defined on variable not in the model: {new_cpd}")

        # Position of an already registered CPD for the same variable, if any.
        existing_index = next(
            (
                position
                for position, old_cpd in enumerate(self.cpds)
                if old_cpd.variable == new_cpd.variable
            ),
            None,
        )
        if existing_index is None:
            self.cpds.append(new_cpd)
        else:
            logger.warning(f"Replacing existing CPD for {new_cpd.variable}")
            self.cpds[existing_index] = new_cpd
[docs]
def get_cpds(
    self, node: Optional[Any] = None
) -> Union[List[FunctionalCPD], FunctionalCPD]:
    """
    Returns the CPD of `node`, or all CPDs added to the model so far when
    `node` is not specified.

    The lookup itself is delegated to the parent class.

    Parameters
    ----------
    node: any hashable python object (optional)
        The node whose CPD we want. If node is not specified, all the CPDs
        added to the model are returned.

    Returns
    -------
    The FunctionalCPD of `node`, or a list of FunctionalCPDs when `node` is
    not specified.

    Examples
    --------
    >>> from pgmpy.factors.hybrid import FunctionalCPD
    >>> from pgmpy.models import FunctionalBayesianNetwork
    >>> import numpy as np
    >>> import pyro.distributions as dist
    >>> model = FunctionalBayesianNetwork([("x1", "x2"), ("x2", "x3")])
    >>> cpd1 = FunctionalCPD("x1", lambda _: dist.Normal(0, 1))
    >>> cpd2 = FunctionalCPD(
    ...     "x2", lambda parent: dist.Normal(parent["x1"] + 2.0, 1), parents=["x1"]
    ... )
    >>> cpd3 = FunctionalCPD(
    ...     "x3", lambda parent: dist.Normal(parent["x2"] + 0.3, 2), parents=["x2"]
    ... )
    >>> model.add_cpds(cpd1, cpd2, cpd3)
    >>> model.get_cpds()
    """
    parent_lookup = super(FunctionalBayesianNetwork, self).get_cpds
    return parent_lookup(node)
[docs]
def remove_cpds(self, *cpds: FunctionalCPD) -> None:
    """
    Removes the given `cpds` from the model.

    The removal itself is delegated to the parent class.

    Parameters
    ----------
    *cpds: FunctionalCPD objects
        The FunctionalCPD objects that need to be removed from the model.

    Examples
    --------
    >>> from pgmpy.factors.hybrid import FunctionalCPD
    >>> from pgmpy.models import FunctionalBayesianNetwork
    >>> import numpy as np
    >>> import pyro.distributions as dist
    >>> model = FunctionalBayesianNetwork([("x1", "x2"), ("x2", "x3")])
    >>> cpd1 = FunctionalCPD("x1", lambda _: dist.Normal(0, 1))
    >>> cpd2 = FunctionalCPD(
    ...     "x2", lambda parent: dist.Normal(parent["x1"] + 2.0, 1), parents=["x1"]
    ... )
    >>> cpd3 = FunctionalCPD(
    ...     "x3", lambda parent: dist.Normal(parent["x2"] + 0.3, 2), parents=["x2"]
    ... )
    >>> model.add_cpds(cpd1, cpd2, cpd3)
    >>> for cpd in model.get_cpds():
    ...     print(cpd)
    ...
    >>> model.remove_cpds(cpd2, cpd3)
    >>> for cpd in model.get_cpds():
    ...     print(cpd)
    ...
    """
    parent_remove = super(FunctionalBayesianNetwork, self).remove_cpds
    return parent_remove(*cpds)
[docs]
def check_model(self) -> bool:
    """
    Validates that the CPDs of the model agree with the graph structure.

    For every node whose associated CPD is a FunctionalCPD, the parents
    declared on the CPD must be exactly the parents of the node in the graph.
    Nodes for which `get_cpds` does not return a FunctionalCPD are skipped.

    Returns
    -------
    check: boolean
        True if all the checks pass.

    Raises
    ------
    ValueError
        If the parents of a CPD differ from the node's parents in the graph.
    """
    for node in self.nodes():
        node_cpd = self.get_cpds(node=node)
        if not isinstance(node_cpd, FunctionalCPD):
            continue

        declared_parents = set(node_cpd.parents)
        graph_parents = set(self.get_parents(node))
        if declared_parents != graph_parents:
            raise ValueError(
                f"CPD associated with {node} doesn't have proper parents associated with it."
            )

    return True
[docs]
def simulate(
    self,
    n_samples: int = 1000,
    do: Optional[Dict[Hashable, Any]] = None,
    virtual_intervention: Optional[List[FunctionalCPD]] = None,
    seed: Optional[int] = None,
) -> pd.DataFrame:
    """
    Simulate samples from the model.

    Nodes are sampled in topological order so that the samples of a node's
    parents are always available when the node itself is sampled.

    Parameters
    ----------
    n_samples : int, optional (default: 1000)
        Number of samples to generate

    do : dict, optional
        Specifies hard interventions to the model. The dict should be of
        the form {variable: value}. Incoming edges into each intervened
        variable are severed and the variable is set to the given constant
        for all rows.

    virtual_intervention : list[FunctionalCPD], optional
        A list of unconditional FunctionalCPD objects (no parents) that
        replace the corresponding node's CPD during simulation (i.e.,
        stochastic interventions like do(X ~ Normal(...))).

    seed : int, optional
        The seed value for the random number generator.

    Returns
    -------
    pandas.DataFrame
        Simulated samples with columns corresponding to network variables

    Raises
    ------
    ValueError
        If `do` or `virtual_intervention` refer to nodes not in the model,
        if an element of `virtual_intervention` is not a FunctionalCPD or has
        parents, if a node appears in both `do` and `virtual_intervention`,
        or if a node that needs to be sampled has no CPD associated with it.

    Examples
    --------
    >>> from pgmpy.factors.hybrid import FunctionalCPD
    >>> from pgmpy.models import FunctionalBayesianNetwork
    >>> import numpy as np
    >>> import pyro.distributions as dist
    >>> model = FunctionalBayesianNetwork([("x1", "x2"), ("x2", "x3")])
    >>> cpd1 = FunctionalCPD("x1", lambda _: dist.Normal(0, 1))
    >>> cpd2 = FunctionalCPD(
    ...     "x2", lambda parent: dist.Normal(parent["x1"] + 2.0, 1), parents=["x1"]
    ... )
    >>> cpd3 = FunctionalCPD(
    ...     "x3", lambda parent: dist.Normal(parent["x2"] + 0.3, 2), parents=["x2"]
    ... )
    >>> model.add_cpds(cpd1, cpd2, cpd3)
    >>> model.simulate(n_samples=1000)
    """
    # Step 0: Set the seed if specified, check arguments and initialize data structures.
    if seed is not None:
        pyro.set_rng_seed(seed)

    if do is None:
        do = {}
    if virtual_intervention is None:
        virtual_intervention = []

    # Built once; reused for every membership check below.
    model_nodes = set(self.nodes())

    # Check if all variables in do and virtual_intervention are valid
    extra_do = set(do.keys()) - model_nodes
    if extra_do:
        raise ValueError(
            f"`do` contains nodes not in the model: {sorted(extra_do)}"
        )

    vi_map = {}
    for cpd in virtual_intervention:
        if not isinstance(cpd, FunctionalCPD):
            raise ValueError(
                "`virtual_intervention` must be a list of FunctionalCPD objects. "
                f"Got {type(cpd)}"
            )
        if cpd.variable not in model_nodes:
            raise ValueError(
                f"Virtual intervention CPD variable not in the model: {cpd.variable}"
            )
        if cpd.parents:
            raise ValueError(
                f"Virtual intervention CPD for {cpd.variable} must be unconditional (no parents)."
            )
        vi_map[cpd.variable] = cpd

    overlap = set(do.keys()) & set(vi_map.keys())
    if overlap:
        raise ValueError(
            "Cannot specify both `do` and `virtual_intervention` for the same node(s): "
            f"{sorted(overlap)}"
        )

    nodes = list(nx.topological_sort(self))
    samples = pd.DataFrame(index=range(n_samples))

    # Step 1: Simulate data
    for node in nodes:
        # Step 1.1: Handle hard interventions
        if node in do:
            samples[node] = np.full(n_samples, do[node])
            continue

        # Step 1.2: Handle virtual interventions
        if node in vi_map:
            samples[node] = vi_map[node].sample(
                n_samples=n_samples, parent_sample=None
            )
            continue

        # Step 1.3: Standard sampling from the node's CPD
        cpd = self.get_cpds(node)
        if cpd is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(f"No CPD associated with the node: {node}.")

        parent_samples = samples[cpd.parents] if cpd.parents else None
        samples[node] = cpd.sample(
            n_samples=n_samples, parent_sample=parent_samples
        )

    # Step 2: Return the simulated samples
    return samples
[docs]
def fit(
    self,
    data: pd.DataFrame,
    estimator: str = "SVI",
    optimizer: "Optional[pyro.optim.PyroOptim]" = None,
    prior_fn: Optional[Callable] = None,
    num_steps: int = 1000,
    seed: Optional[int] = None,
    nuts_kwargs: Optional[Dict] = None,
    mcmc_kwargs: Optional[Dict] = None,
) -> Dict[str, Any]:
    """
    Fit the Bayesian network to data using Pyro's stochastic variational
    inference (SVI) or MCMC (NUTS).

    Parameters
    ----------
    data: pandas.DataFrame
        DataFrame with observations of variables. Must contain a column for
        every node of the model.

    estimator: str (default: "SVI")
        Fitting method to use. Currently supports "SVI" and "MCMC"
        (case-insensitive).

    optimizer: Instance of pyro optimizer (default: None)
        Only used if `estimator` is "SVI". The optimizer to use for optimization.
        If None, a new `pyro.optim.Adam({"lr": 1e-2})` is created for the call.

    prior_fn: function
        Required if `estimator` is "MCMC", ignored otherwise. A function
        that takes no arguments and returns a dictionary of pyro
        distributions for each parameter in the model.

    num_steps: int (default: 1000)
        Number of optimization steps. For SVI it is the number of calls to
        `pyro.infer.SVI.step`. For MCMC, it is the `num_samples`
        argument for pyro.infer.MCMC.

    seed: int (default: None)
        Seed value for random number generator.

    nuts_kwargs: dict (default: None)
        Only used if `estimator` is "MCMC". Additional arguments to pass to
        pyro.infer.NUTS.

    mcmc_kwargs: dict (default: None)
        Only used if `estimator` is "MCMC". Additional arguments to pass to
        pyro.infer.MCMC.

    Returns
    -------
    dict: If `estimator` is "SVI", returns a dictionary with all the entries
        currently in Pyro's param store.
        If `estimator` is "MCMC", returns a dictionary of posterior samples for each parameter.

    Raises
    ------
    ValueError
        If `data` is not a DataFrame, `num_steps` is not an int, `estimator`
        is not one of "SVI"/"MCMC", `prior_fn` is not callable when
        `estimator` is "MCMC", a node has no column in `data`, or a node has
        no CPD associated with it.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> import pyro
    >>> import pyro.distributions as dist
    >>> import torch
    >>> from torch.distributions import constraints
    >>> from pgmpy.factors.hybrid import FunctionalCPD
    >>> from pgmpy.models import FunctionalBayesianNetwork
    >>> model = FunctionalBayesianNetwork([("x1", "x2")])
    >>> x1 = np.random.normal(0.2, 0.8, size=10000)
    >>> x2 = np.random.normal(0.6 + x1, 1)
    >>> data = pd.DataFrame({"x1": x1, "x2": x2})

    Fitting with SVI; the CPD functions take the parent values only.

    >>> def x1_fn(parents):
    ...     mu = pyro.param("x1_mu", torch.tensor(1.0))
    ...     sigma = pyro.param(
    ...         "x1_sigma", torch.tensor(1.0), constraint=constraints.positive
    ...     )
    ...     return dist.Normal(mu, sigma)
    ...
    >>> def x2_fn(parents):
    ...     intercept = pyro.param("x2_inter", torch.tensor(1.0))
    ...     sigma = pyro.param(
    ...         "x2_sigma", torch.tensor(1.0), constraint=constraints.positive
    ...     )
    ...     return dist.Normal(intercept + parents["x1"], sigma)
    ...
    >>> cpd1 = FunctionalCPD("x1", fn=x1_fn)
    >>> cpd2 = FunctionalCPD("x2", fn=x2_fn, parents=["x1"])
    >>> model.add_cpds(cpd1, cpd2)
    >>> params = model.fit(data, estimator="SVI", num_steps=100)
    >>> print(params)

    Fitting with MCMC; the CPD functions take the sampled priors and the
    parent values.

    >>> def prior_fn():
    ...     return {
    ...         "x1_mu": dist.Uniform(0, 1),
    ...         "x1_sigma": dist.HalfNormal(5),
    ...         "x2_inter": dist.Normal(1.0, 1.0),
    ...         "x2_sigma": dist.HalfNormal(1),
    ...     }
    ...
    >>> def x1_fn(priors, parents):
    ...     return dist.Normal(priors["x1_mu"], priors["x1_sigma"])
    ...
    >>> def x2_fn(priors, parents):
    ...     return dist.Normal(
    ...         priors["x2_inter"] + parents["x1"], priors["x2_sigma"]
    ...     )
    ...
    >>> cpd1 = FunctionalCPD("x1", fn=x1_fn)
    >>> cpd2 = FunctionalCPD("x2", fn=x2_fn, parents=["x1"])
    >>> model.add_cpds(cpd1, cpd2)
    >>> params = model.fit(data, estimator="MCMC", prior_fn=prior_fn, num_steps=100)
    >>> print(params["x1_mu"].mean(), params["x1_sigma"].mean())
    """
    # Step 0: Checks for specified arguments.
    if not isinstance(data, pd.DataFrame):
        raise ValueError(
            f"data should be a pandas.DataFrame object. Got: {type(data)}."
        )

    if not isinstance(num_steps, int):
        raise ValueError(f"num_steps should be an integer. Got: {type(num_steps)}.")

    method = estimator.lower()
    if method not in ("svi", "mcmc"):
        raise ValueError(
            f"`estimator` argument needs to be either 'SVI' or 'MCMC'. Got: {estimator}."
        )

    # Without this check a missing prior_fn only surfaces deep inside NUTS
    # as "'NoneType' object is not callable".
    if method == "mcmc" and not callable(prior_fn):
        raise ValueError(
            "`prior_fn` must be a callable returning a dict of pyro distributions "
            f"when `estimator` is 'MCMC'. Got: {type(prior_fn)}."
        )

    # Step 1: Preprocess the data and initialize data structures.
    import torch

    if seed is not None:
        pyro.set_rng_seed(seed)

    sort_nodes = list(nx.topological_sort(self))

    tensor_data = {}
    for node in sort_nodes:
        if node not in data.columns:
            raise ValueError(f"data doesn't contain column for the node: {node}.")
        tensor_data[node] = torch.tensor(
            data[node].values,
            dtype=config.get_dtype(),
            device=config.get_device(),
        )

    nuts_kwargs = nuts_kwargs or {}
    mcmc_kwargs = mcmc_kwargs or {}

    cpds_dict = {node: self.get_cpds(node) for node in sort_nodes}
    nodes_without_cpd = [node for node in sort_nodes if cpds_dict[node] is None]
    if nodes_without_cpd:
        raise ValueError(f"No CPD associated with the node(s): {nodes_without_cpd}.")

    n_obs = data.shape[0]

    # Step 2: Fit the model using SVI.
    if method == "svi":
        # A fresh optimizer per call so optimizer state is never shared
        # between separate fits.
        if optimizer is None:
            optimizer = pyro.optim.Adam({"lr": 1e-2})

        # The guide is intentionally empty: every sample site in the model
        # below is observed, so only the `pyro.param` values used inside
        # the CPD functions are optimized.
        def guide(tensor_data):
            pass

        # Step 2.1: Define the combined model for SVI.
        def combined_model_svi(tensor_data):
            with pyro.plate("data", n_obs):
                for node in sort_nodes:
                    node_cpd = cpds_dict[node]
                    pyro.sample(
                        f"{node}",
                        node_cpd.fn({p: tensor_data[p] for p in node_cpd.parents}),
                        obs=tensor_data[node],
                    )

        # Step 2.2: Run the optimization.
        svi = pyro.infer.SVI(
            model=combined_model_svi,
            guide=guide,
            optim=optimizer,
            loss=pyro.infer.Trace_ELBO(),
        )

        for step in range(num_steps):
            loss = svi.step(tensor_data)
            if step % 50 == 0:
                logger.info(f"Step {step} | Loss: {loss:.4f}")

        # Step 2.3: Return everything currently in Pyro's param store.
        return dict(pyro.get_param_store().items())

    # Step 3: Fit the model using MCMC.
    # Step 3.1: Define the combined model for MCMC.
    def combined_model_mcmc(tensor_data):
        priors_dists = prior_fn()
        priors_vals = {
            name: pyro.sample(name, d) for name, d in priors_dists.items()
        }

        with pyro.plate("data", n_obs):
            for node in sort_nodes:
                node_cpd = cpds_dict[node]
                dist_node = node_cpd.fn(
                    priors_vals,
                    {p: tensor_data[p] for p in node_cpd.parents},
                )
                pyro.sample(f"{node}", dist_node, obs=tensor_data[node])

    # Step 3.2: Run NUTS and return the posterior samples.
    nuts_kernel = pyro.infer.NUTS(combined_model_mcmc, **nuts_kwargs)
    mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=num_steps, **mcmc_kwargs)
    mcmc.run(tensor_data)

    return mcmc.get_samples()