# Source code for pgmpy.example_models._base
import gzip
import io
from skbase.base import BaseObject
from skbase.lookup import all_objects
from pgmpy.base import DAG
from pgmpy.readwrite import BIFReader
from pgmpy.utils.hf_hub import read_hf_file
class _BaseExampleModel(BaseObject):
    """
    Base class for all example models in pgmpy.

    Inherits from `skbase.base.BaseObject` to utilize its tag and lookup functionality.

    Notes
    -----
    `_get_raw_data` reads `cls.data_url`, which is not defined on this base class;
    each concrete example model has to provide it (the file name inside `repo_id`).
    """

    # Default tag values. Concrete model classes are expected to override them
    # (presumably through their own `_tags` -- confirm against the subclasses).
    _tags = {
        # `None` (not a type object) so that a class without a name is skipped by
        # the `is not None` filter in `list_models` instead of leaking a
        # non-string value into the sorted list of names.
        "name": None,
        "n_nodes": None,
        "n_edges": None,
        # The `bool` type object is used as the placeholder default for the
        # boolean tags below.
        "is_parameterized": bool,
        "is_discrete": bool,
        "is_continuous": bool,
        "is_hybrid": bool,
    }

    # Repository and revision handed to `read_hf_file` when fetching a model file.
    repo_id = "pgmpy/example_models"
    revision = "main"

    @classmethod
    def _get_raw_data(cls) -> bytes:
        """
        Fetches the raw model file through `read_hf_file`.

        Returns
        -------
        bytes
            Content of the file `cls.data_url` from `cls.repo_id` at `cls.revision`.
        """
        return read_hf_file(
            repo_id=cls.repo_id,
            filename=cls.data_url,
            revision=cls.revision,
        )
class DiscreteMixin:
    """
    Mixin class for loading discrete Bayesian networks from gzip-compressed BIF files.

    The host class must provide a `_get_raw_data()` classmethod returning the
    compressed file content as bytes.
    """

    @classmethod
    def load_model_object(cls):
        """
        Decompresses the raw file, decodes it as UTF-8 and parses it with `BIFReader`.

        Returns
        -------
        The model produced by `BIFReader.get_model()`.
        """
        compressed = cls._get_raw_data()
        bif_text = gzip.decompress(compressed).decode("utf-8")
        reader = BIFReader(string=bif_text)
        return reader.get_model()
class BIFMixin:
    """
    Mixin class for loading discrete Bayesian networks from plain (non-gzipped) BIF files.

    The host class must provide a `_get_raw_data()` classmethod returning the
    file content as bytes.
    """

    @classmethod
    def load_model_object(cls):
        """
        Decodes the raw file as UTF-8 and parses it with `BIFReader`.

        Returns
        -------
        The model produced by `BIFReader.get_model()`.
        """
        bif_text = cls._get_raw_data().decode("utf-8")
        reader = BIFReader(string=bif_text)
        return reader.get_model()
class ContinuousMixin:
    """
    Mixin class for loading continuous Bayesian networks from serialized model files.

    The host class must provide a `_get_raw_data()` classmethod returning the
    file content as bytes.
    NOTE(review): the files are presumably JSON -- the format is decided by
    `LinearGaussianBayesianNetwork.load`, not by this block.
    """

    @classmethod
    def load_model_object(cls):
        """
        Wraps the raw bytes in an in-memory binary stream and hands it to
        `LinearGaussianBayesianNetwork.load`.

        Returns
        -------
        The object returned by `LinearGaussianBayesianNetwork.load`.
        """
        # Imported at call time rather than at module import
        # (presumably to avoid a circular import -- confirm).
        from pgmpy.models import LinearGaussianBayesianNetwork

        buffer = io.BytesIO(cls._get_raw_data())
        return LinearGaussianBayesianNetwork.load(buffer)
class DAGMixin:
    """
    Mixin class for loading DAGs from dagitty string format.

    The host class must provide a `_get_raw_data()` classmethod returning the
    file content as bytes.
    """

    @classmethod
    def load_model_object(cls):
        """
        Decodes the raw file as UTF-8 and builds a DAG with `DAG.from_dagitty`.

        Returns
        -------
        The object returned by `DAG.from_dagitty`.
        """
        dagitty_text = cls._get_raw_data().decode("utf-8")
        return DAG.from_dagitty(string=dagitty_text)
# [docs]
def load_model(name: str):
    """
    Loads an example model by name.

    To find all available example models, use the `list_models()` function.

    Parameters
    ----------
    name : str
        Name of the example model to load.

    Returns
    -------
    model: pgmpy.base.DAG or pgmpy.models.DiscreteBayesianNetwork or pgmpy.models.LinearGaussianBayesianNetwork or
        pgmpy.models.FunctionalBayesianNetwork
        The loaded example model.

    Raises
    ------
    ValueError
        If no example model class carries the tag `name` equal to `name`.

    Examples
    --------
    # Loading a discrete Bayesian network with parameters.

    >>> from pgmpy.example_models import load_model
    >>> model = load_model("bnlearn/alarm")
    >>> print(model)
    DiscreteBayesianNetwork named 'unknown' with 37 nodes and 46 edges
    >>> len(model.nodes())
    37
    >>> model.get_cpds("HISTORY")
    <TabularCPD representing P(HISTORY:2 | LVFAILURE:2) at 0x7d4527a84230>

    # Loading a DAG without parameters.

    >>> model = load_model("dagitty/acid_1996")
    >>> print(model)
    DAG with 18 nodes and 22 edges
    >>> len(model.nodes())
    18

    # Loading a continuous Bayesian network with parameters.

    >>> model = load_model("bnlearn/arth150")
    >>> print(model)
    LinearGaussianBayesianNetwork with 107 nodes and 150 edges

    # Loading a bnRep discrete Bayesian network.

    >>> model = load_model("bnrep/asia")
    >>> print(model)
    DiscreteBayesianNetwork named 'unknown' with 8 nodes and 8 edges
    """
    # Look up the example-model classes whose `name` tag equals the request.
    target_model = all_objects(
        object_types=_BaseExampleModel,
        package_name="pgmpy.example_models",
        filter_tags={"name": name},
        return_names=False,
    )
    if not target_model:
        # This function loads models, so the hint refers to models (the message
        # previously said "datasets").
        raise ValueError(f"Model with name '{name}' not found. Please use list_models() to see available models.")

    # Names are expected to be unique; the first match is loaded.
    return target_model[0].load_model_object()
# [docs]
def list_models(**filter_tags) -> list[str]:
    """
    Lists all available example models.

    The models can be filtered based on their tags by providing keyword arguments. The available tags are:

    - name: str
    - n_nodes: No. of nodes in the model.
    - n_edges: No. of edges in the model.
    - is_parameterized: Whether it is just the network structure or also has parameters (CPDs) defined.
    - is_discrete: Whether the model has only discrete variables / parameterization.
    - is_continuous: Whether the model has only continuous variables / parameterization.
    - is_hybrid: Whether the model has both discrete and continuous variables / parameterization.

    Parameters
    ----------
    **filter_tags
        Tag name / tag value pairs used to narrow down the result. Only the tag
        names listed above are accepted.

    Returns
    -------
    list
        Sorted list of names of all available example models matching the filters.

    Raises
    ------
    ValueError
        If any keyword argument is not one of the known tag names.

    Examples
    --------
    >>> from pgmpy.example_models import list_models
    >>> list_models()
    ['bnlearn/alarm', 'bnlearn/arth150', ..... ]
    >>> list_models(is_discrete=True)
    ['bnlearn/alarm', 'bnlearn/asia', 'bnlearn/cancer', ..... ]
    >>> list_models(is_parameterized=False)
    ['dagitty/acid_1996', ...., ]
    """
    # Reject unknown filter keys up front, before any lookup work is done.
    known = set(_BaseExampleModel._tags)
    unknown = set(filter_tags) - known
    if unknown:
        raise ValueError(
            f"Unrecognized filter argument(s): {sorted(unknown)}. Valid filter tags are: {sorted(known)}."
        )

    matches = all_objects(
        object_types=_BaseExampleModel,
        package_name="pgmpy.example_models",
        return_names=False,
        filter_tags=filter_tags,
    )

    # Collect the `name` tag of every match, skipping classes without one.
    names = []
    for model_cls in matches:
        tag = model_cls.get_class_tag("name")
        if tag is not None:
            names.append(tag)
    names.sort()
    return names