Functional Bayesian Networks#
Functional Bayesian Networks (FBNs) are Bayesian networks where each CPD is a Python function that returns a Pyro distribution. This lets you model arbitrary discrete, continuous, or mixed relationships, while keeping standard graph semantics for sampling, interventions, and learning.
Similar to other Bayesian Network classes in pgmpy, there are two main components of the model:
- The graphical structure of the model,
- The parameterization of the model (defined using `FunctionalCPD`).
This tutorial introduces Functional Bayesian Networks (FBNs) and the accompanying Functional CPDs in pgmpy. You’ll learn how to:
- Define FunctionalCPDs as functions that return Pyro distributions.
- Build FunctionalBayesianNetworks on mixed data.
- Simulate data from the model.
- Use vectorized CPDs for performance.
- Fit simple parametric FBNs using SVI and MCMC via Pyro.
[1]:
import pandas as pd
import torch
import pyro
import pyro.distributions as dist
import pgmpy
pgmpy.config.set_backend("torch")
from pgmpy.models import FunctionalBayesianNetwork
from pgmpy.factors.hybrid import FunctionalCPD
# Reproducibility
pyro.set_rng_seed(123)
Functional CPDs: the core idea#
A Functional CPD has three arguments:
- `variable`: The variable for which the FunctionalCPD is being defined. For a fully parameterized model, each node in the model needs a `FunctionalCPD` associated with it.
- `fn`: A Python callable that takes a dictionary of parents’ values as input and returns a Pyro distribution.
- `parents`: The parents of `variable` in the model.
Example 1: A simple Gaussian chain: x1 → x2 → x3#
[2]:
# Build the structure
gauss_chain = FunctionalBayesianNetwork([("x1", "x2"), ("x2", "x3")])
# Define CPDs
cpd_x1 = FunctionalCPD("x1", fn=lambda _: dist.Normal(0.0, 1.0))
cpd_x2 = FunctionalCPD("x2", fn=lambda parents: dist.Normal(1.0 + 0.8 * parents["x1"], 0.5), parents=["x1"])
cpd_x3 = FunctionalCPD("x3", fn=lambda parents: dist.Normal(0.3 + 1.0 * parents["x2"], 1.0), parents=["x2"])
gauss_chain.add_cpds(cpd_x1, cpd_x2, cpd_x3)
gauss_chain.check_model()
# Draw a few samples
samples_gc = gauss_chain.simulate(n_samples=5, seed=123)
samples_gc
[2]:
|   | x1 | x2 | x3 |
|---|---|---|---|
| 0 | -0.111467 | 1.015461 | 1.525792 |
| 1 | 0.120363 | 0.610113 | 0.519270 |
| 2 | -0.369635 | 0.326770 | 0.861743 |
| 3 | -0.240418 | 0.969617 | 1.934877 |
| 4 | -1.196924 | -0.011801 | 0.641020 |
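Under the hood, `simulate` performs ancestral sampling: each node is drawn from its CPD given its parents' already-sampled values. As a quick sanity check, the same chain can be reproduced with only the Python standard library and compared against the closed-form moments it implies (a sketch independent of pgmpy; the coefficients are copied from the CPDs above):

```python
import random
import statistics

random.seed(0)

# Ancestral sampling by hand: draw each node from its CPD
# given the already-sampled values of its parents.
def sample_chain():
    x1 = random.gauss(0.0, 1.0)
    x2 = random.gauss(1.0 + 0.8 * x1, 0.5)
    x3 = random.gauss(0.3 + 1.0 * x2, 1.0)
    return x1, x2, x3

draws = [sample_chain() for _ in range(100_000)]
x2s = [d[1] for d in draws]
x3s = [d[2] for d in draws]

# Closed-form moments implied by the chain:
#   E[x2] = 1.0,  Var[x2] = 0.8**2 * 1 + 0.5**2 = 0.89
#   E[x3] = 1.3,  Var[x3] = Var[x2] + 1.0**2    = 1.89
print(statistics.mean(x2s), statistics.variance(x2s))
print(statistics.mean(x3s), statistics.variance(x3s))
```

The empirical moments should land close to the closed-form values, which is a useful cross-check whenever you hand-write linear-Gaussian CPDs.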
Example 2: Complex model with mixture data#
[3]:
complex_bn = FunctionalBayesianNetwork(
[("x1", "w"), ("x2", "w"), ("x1", "y"),
("x2", "y"), ("w", "y"), ("y", "z"),
("w", "z"), ("y", "c"), ("w", "c")]
)
# Roots
cpd_x1 = FunctionalCPD("x1", fn=lambda _: dist.Normal(0.0, 1.0))
cpd_x2 = FunctionalCPD("x2", fn=lambda _: dist.Normal(0.5, 1.2))
# Continuous mediator: w = 0.7*x1 - 0.3*x2 + ε
cpd_w = FunctionalCPD(
"w",
fn=lambda parents: dist.Normal(0.7 * parents["x1"] - 0.3 * parents["x2"], 0.5),
parents=["x1", "x2"]
)
# Bernoulli target with logistic link: y ~ Bernoulli(sigmoid(-0.7 + 1.5*x1 + 0.8*x2 + 1.2*w))
cpd_y = FunctionalCPD(
"y",
fn=lambda parents: dist.Bernoulli(logits=(-0.7 + 1.5 * parents["x1"] + 0.8 * parents["x2"] + 1.2 * parents["w"])),
parents=["x1", "x2", "w"]
)
# Downstream Bernoulli influenced by y and w
cpd_z = FunctionalCPD(
"z",
fn=lambda parents: dist.Bernoulli(logits=(-1.2 + 0.8 * parents["y"] + 0.2 * parents["w"])),
parents=["y", "w"]
)
# Continuous outcome depending on y and w: c = 0.2 + 0.5*y + 0.3*w + ε
cpd_c = FunctionalCPD(
"c",
fn=lambda parents: dist.Normal(0.2 + 0.5 * parents["y"] + 0.3 * parents["w"], 0.7),
parents=["y", "w"]
)
complex_bn.add_cpds(cpd_x1, cpd_x2, cpd_w, cpd_y, cpd_z, cpd_c)
complex_bn.check_model()
# Simulate data from it
complex_bn.simulate(n_samples=8, seed=123)
[3]:
|   | x1 | x2 | w | y | z | c |
|---|---|---|---|---|---|---|
| 0 | -0.111467 | 0.888683 | -0.363940 | 0.0 | 0.0 | 0.848689 |
| 1 | 0.120363 | 0.369773 | -0.469728 | 0.0 | 1.0 | -0.909814 |
| 2 | -0.369635 | 0.752397 | -0.719913 | 0.0 | 1.0 | 0.266363 |
| 3 | -0.240418 | 0.030989 | -0.391040 | 0.0 | 0.0 | -0.774038 |
| 4 | -1.196924 | 0.781968 | -1.086602 | 0.0 | 1.0 | 0.365244 |
| 5 | 0.209269 | 1.298313 | 0.468003 | 1.0 | 0.0 | 2.390150 |
| 6 | -0.972355 | 0.923385 | -1.151972 | 0.0 | 0.0 | -0.164933 |
| 7 | -0.755045 | 1.667385 | -1.473914 | 0.0 | 0.0 | 0.495661 |
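The logistic link in `cpd_y` makes the implied success probability easy to check by hand. For example, plugging the parent values from row 0 of the table above into the same logits gives a P(y = 1) of roughly 0.36, consistent with the sampled y = 0 (a stdlib check, no Pyro needed):

```python
import math

# Parent values taken from row 0 of the simulated table above
x1, x2, w = -0.111467, 0.888683, -0.363940

# Same logits expression as cpd_y: -0.7 + 1.5*x1 + 0.8*x2 + 1.2*w
logit = -0.7 + 1.5 * x1 + 0.8 * x2 + 1.2 * w

# Bernoulli(logits=...) means P(y=1) = sigmoid(logit)
p_y1 = 1.0 / (1.0 + math.exp(-logit))
print(round(p_y1, 3))  # prints 0.356
```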
Vectorized CPDs for speed#
Set vectorized=True and have your fn(parent_df) return a batched Pyro distribution whose batch size equals the number of rows in the provided parent DataFrame. This makes sampling much faster for large n_samples.
[4]:
from pgmpy import config
vec_bn = FunctionalBayesianNetwork([("x1", "x2")])
cpd_x1 = FunctionalCPD("x1", fn=lambda _: dist.Normal(0.0, 1.0))
def x2_fn_vec(P):
x1 = torch.tensor(P["x1"].values, dtype=config.get_dtype(), device=config.get_device())
mu = 0.5 + 0.9 * x1
sigma = torch.full_like(mu, 0.3)
return dist.Normal(mu, sigma)
cpd_x2 = FunctionalCPD("x2", fn=x2_fn_vec, parents=["x1"], vectorized=True)
vec_bn.add_cpds(cpd_x1, cpd_x2)
vec_bn.check_model()
# Large draw to highlight performance of vectorized CPDs
vec_samples = vec_bn.simulate(n_samples=5000, seed=123)
vec_samples.head()
[4]:
|   | x1 | x2 |
|---|---|---|
| 0 | -0.111467 | 0.796616 |
| 1 | 0.120363 | 0.786959 |
| 2 | -0.369635 | -0.077290 |
| 3 | -0.240418 | 0.097413 |
| 4 | -1.196924 | -0.559637 |
Parameter learning with SVI#
When CPDs contain Pyro parameters (pyro.param(...)), you can fit them to data using model.fit(..., estimator="SVI"). Below, we synthesize data from a simple linear-Gaussian model and then recover the parameters.
[5]:
# Generate synthetic data
true_mu, true_sigma = 0.8, 0.6
N = 2000
x1 = torch.normal(mean=true_mu, std=true_sigma, size=(N,))
# Vectorized draw for x2: omit `size` when `mean` is already a tensor
x2 = torch.normal(mean=1.2 + x1, std=0.7)  # or: x2 = 1.2 + x1 + 0.7 * torch.randn_like(x1)
data = pd.DataFrame({"x1": x1.numpy(), "x2": x2.numpy()})
from torch.distributions import constraints
import pyro, pyro.distributions as dist
pyro.clear_param_store() # helpful if you re-run the cell
svi_bn = FunctionalBayesianNetwork([("x1", "x2")])
def x1_fn(_):
mu = pyro.param("x1_mu", torch.tensor(0.0))
sigma = pyro.param("x1_sigma", torch.tensor(1.0), constraint=constraints.positive)
return dist.Normal(mu, sigma)
def x2_fn(p):
inter = pyro.param("x2_inter", torch.tensor(0.0))
sigma = pyro.param("x2_sigma", torch.tensor(1.0), constraint=constraints.positive)
return dist.Normal(inter + p["x1"], sigma)
svi_bn.add_cpds(
FunctionalCPD("x1", fn=x1_fn),
FunctionalCPD("x2", fn=x2_fn, parents=["x1"]),
)
svi_bn.check_model()
# Fit with SVI
params_svi = svi_bn.fit(data, num_steps=300)
{k: v.item() if torch.is_tensor(v) and v.ndim == 0 else v for k, v in params_svi.items()}
[5]:
{'x1_mu': 0.7751224637031555,
'x1_sigma': 0.6131476759910583,
'x2_inter': 1.1970738172531128,
'x2_sigma': 0.6990763545036316}
Bayesian inference with MCMC#
For fully Bayesian inference, supply priors through a `prior_fn` and fit with `estimator="MCMC"`.

[6]:
mcmc_bn = FunctionalBayesianNetwork([("x1", "x2")])
# Priors with matched dtype/device (avoid dtype mismatches)
dtype = config.get_dtype()
device = config.get_device()
def prior_fn():
t = lambda v: torch.tensor(v, dtype=dtype, device=device)
return {
"x1_mu": dist.Normal(t(0.0), t(5.0)),
"x1_sigma": dist.HalfNormal(t(2.0)),
"x2_inter": dist.Normal(t(0.0), t(5.0)),
"x2_sigma": dist.HalfNormal(t(2.0)),
}
# CPDs consume *sampled* prior values (from the model) + parents
def x1_fn_prior(priors, _):
return dist.Normal(priors["x1_mu"], priors["x1_sigma"])
def x2_fn_prior(priors, parents):
return dist.Normal(priors["x2_inter"] + parents["x1"], priors["x2_sigma"])
mcmc_bn.add_cpds(
FunctionalCPD("x1", fn=x1_fn_prior),
FunctionalCPD("x2", fn=x2_fn_prior, parents=["x1"]),
)
pyro.clear_param_store()
post = mcmc_bn.fit(
data,
estimator="MCMC",
prior_fn=prior_fn,
num_steps=200,
nuts_kwargs={"target_accept_prob": 0.8},
mcmc_kwargs={"num_chains": 1, "warmup_steps": 200},
)
# Peek at posterior summaries
{k: (v.mean().item(), v.std().item()) for k, v in post.items() if torch.is_tensor(v)}
Sample: 100%|█████████████████████████████████████████| 400/400 [00:01, 221.79it/s, step size=6.89e-01, acc. prob=0.881]
[6]:
{'x1_mu': (0.7773640349223626, 0.014470493417145083),
'x1_sigma': (0.612960799064066, 0.010122547469200738),
'x2_inter': (1.197045500167558, 0.0148845266501543),
'x2_sigma': (0.7008062278971664, 0.01043389183451872)}
Interventions and conditioning (preview)#
`do(X = value)` severs incoming edges into `X` and replaces its CPD with a constant or a new distribution; in an FBN, this amounts to swapping in a FunctionalCPD that ignores its parents and returns `dist.Delta(value)` (or any desired distribution).

Conditioning on continuous point evidence should use likelihood weighting or proper inference (SVI/MCMC), not plain rejection sampling. The high-level recipe for likelihood weighting is:

1. Sample all non-evidence nodes in topological order (respecting any `do(...)` replacements).
2. Clamp evidence nodes to their observed values.
3. Weight each draw by the product of evidence likelihoods under their parents’ sampled values.
4. Normalize weights and either keep weighted samples or resample for an unweighted posterior sample.
Future versions may expose simulate(do=..., evidence=...) directly in the API.
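The recipe above can be sketched with the standard library on the simple x1 → x2 pair (x1 ~ N(0, 1), x2 ~ N(1 + 0.8·x1, 0.5)). Here we condition on the point evidence x2 = 2.0 and estimate the posterior mean of x1 by weighting each ancestral draw by the evidence likelihood. This is a hand-rolled sketch of the algorithm, not a pgmpy API:

```python
import math
import random

random.seed(0)

def normal_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

evidence_x2 = 2.0
num = 0.0
den = 0.0
for _ in range(100_000):
    # 1. Sample non-evidence nodes in topological order (no do() here)
    x1 = random.gauss(0.0, 1.0)
    # 2. Clamp the evidence node x2 to its observed value (not sampled)
    # 3. Weight by the evidence likelihood under the sampled parents
    w = normal_pdf(evidence_x2, 1.0 + 0.8 * x1, 0.5)
    num += w * x1
    den += w

# 4. Normalize the weights to form the posterior estimate
posterior_mean_x1 = num / den
print(posterior_mean_x1)  # close to the exact conjugate value 0.8989...
```

For this linear-Gaussian pair the exact posterior mean is available in closed form, (0.8 · (2.0 − 1.0) / 0.25) / (1 + 0.64 / 0.25) ≈ 0.8989, so the weighted estimate can be verified directly.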
Key takeaways#
- Functional CPDs let you specify any distribution you can write in Pyro.
- Mixed types (discrete/continuous) are straightforward.
- Use vectorized CPDs for performance on large simulations.
- For learning: quick SVI with `pyro.param(...)`, or fully Bayesian MCMC with `prior_fn`.