Source code for pgmpy.example_models._base

import gzip
import io

from skbase.base import BaseObject
from skbase.lookup import all_objects

from pgmpy.base import DAG
from pgmpy.readwrite import BIFReader
from pgmpy.utils.hf_hub import read_hf_file


class _BaseExampleModel(BaseObject):
    """
    Base class for all example models shipped with pgmpy.

    Inherits from `skbase.base.BaseObject` to utilize its tag and lookup functionality.
    Concrete subclasses override the tags below (most importantly ``name``) and
    define ``data_url``, the path of the model file inside the Hugging Face
    repository ``repo_id``. Subclasses are presumably combined with one of the
    loader mixins in this module that provide ``load_model_object`` — TODO confirm.
    """

    # Default tag values. ``None`` means "not set"; subclasses override these.
    # The defaults must be ``None`` (not a type object such as ``bool``):
    # ``list_models`` skips classes whose ``name`` tag is unset, and a type
    # object mixed into the list of names would make ``sorted`` fail.
    _tags = {
        "name": None,
        "n_nodes": None,
        "n_edges": None,
        "is_parameterized": None,
        "is_discrete": None,
        "is_continuous": None,
        "is_hybrid": None,
    }

    # Hugging Face Hub repository and revision that hold the model files.
    repo_id = "pgmpy/example_models"
    revision = "main"

    # Path of the model file inside ``repo_id``; must be set by subclasses.
    data_url: str

    @classmethod
    def _get_raw_data(cls) -> bytes:
        """
        Fetches the model file from the Hugging Face Hub cache.

        The file ``cls.data_url`` is read from repository ``cls.repo_id`` at
        revision ``cls.revision`` via ``read_hf_file``.

        Returns
        -------
        bytes
            The raw file contents (possibly compressed, depending on the loader).
        """
        return read_hf_file(
            repo_id=cls.repo_id,
            filename=cls.data_url,
            revision=cls.revision,
        )


class DiscreteMixin:
    """
    Mixin class for loading discrete Bayesian networks from gzip-compressed BIF files.

    The host class must provide ``_get_raw_data()`` returning the compressed bytes.
    """

    @classmethod
    def load_model_object(cls):
        """Decompress the gzipped BIF payload, decode it as UTF-8 and parse it into a model."""
        compressed = cls._get_raw_data()
        bif_text = gzip.decompress(compressed).decode("utf-8")
        reader = BIFReader(string=bif_text)
        return reader.get_model()


class BIFMixin:
    """
    Mixin class for loading discrete Bayesian networks from plain (non-gzipped) BIF files.

    The host class must provide ``_get_raw_data()`` returning the BIF file bytes.
    """

    @classmethod
    def load_model_object(cls):
        """Decode the BIF payload as UTF-8 and parse it into a model."""
        bif_text = cls._get_raw_data().decode("utf-8")
        return BIFReader(string=bif_text).get_model()


class ContinuousMixin:
    """
    Mixin class for loading continuous Bayesian networks from JSON files.

    The host class must provide ``_get_raw_data()`` returning the file bytes.
    """

    @classmethod
    def load_model_object(cls):
        """Wrap the raw bytes in an in-memory file and load it with ``LinearGaussianBayesianNetwork.load``."""
        # Imported lazily inside the method, presumably to avoid a circular
        # import with pgmpy.models — TODO confirm.
        from pgmpy.models import LinearGaussianBayesianNetwork

        buffer = io.BytesIO(cls._get_raw_data())
        return LinearGaussianBayesianNetwork.load(buffer)


class DAGMixin:
    """
    Mixin class for loading DAGs from dagitty string format.

    The host class must provide ``_get_raw_data()`` returning the dagitty source bytes.
    """

    @classmethod
    def load_model_object(cls):
        """Decode the raw bytes as UTF-8 and parse them as a dagitty model string."""
        dagitty_source = cls._get_raw_data().decode("utf-8")
        return DAG.from_dagitty(string=dagitty_source)


def load_model(name: str):
    """
    Loads an example model by name.

    To find all available example models, use the `list_models()` function.

    Parameters
    ----------
    name : str
        Name of the example model to load.

    Returns
    -------
    model: pgmpy.base.DAG or pgmpy.models.DiscreteBayesianNetwork or pgmpy.models.LinearGaussianBayesianNetwork or pgmpy.models.FunctionalBayesianNetwork
        The loaded example model.

    Raises
    ------
    ValueError
        If no example model with the given ``name`` exists.

    Examples
    --------
    # Loading a discrete Bayesian network with parameters.
    >>> from pgmpy.example_models import load_model
    >>> model = load_model("bnlearn/alarm")
    >>> print(model)
    DiscreteBayesianNetwork named 'unknown' with 37 nodes and 46 edges
    >>> len(model.nodes())
    37
    >>> model.get_cpds("HISTORY")  # doctest: +ELLIPSIS
    <TabularCPD representing P(HISTORY:2 | LVFAILURE:2) at 0x...>

    # Loading a DAG without parameters.
    >>> model = load_model("dagitty/acid_1996")
    >>> print(model)
    DAG with 18 nodes and 22 edges
    >>> len(model.nodes())
    18

    # Loading a continuous Bayesian network with parameters.
    >>> model = load_model("bnlearn/arth150")
    >>> print(model)
    LinearGaussianBayesianNetwork with 107 nodes and 150 edges

    # Loading a bnRep discrete Bayesian network.
    >>> model = load_model("bnrep/asia")
    >>> print(model)
    DiscreteBayesianNetwork named 'unknown' with 8 nodes and 8 edges
    """
    target_model = all_objects(
        object_types=_BaseExampleModel,
        package_name="pgmpy.example_models",
        filter_tags={"name": name},
        return_names=False,
    )
    if not target_model:
        raise ValueError(f"Model with name '{name}' not found. Please use list_models() to see available models.")
    # If several classes share a name, the first one found by the lookup wins.
    return target_model[0].load_model_object()
def list_models(**filter_tags) -> list[str]:
    """
    Lists all available example models.

    The models can be filtered based on their tags by providing keyword arguments. The available tags are:

    - name: str
    - n_nodes: No. of nodes in the model.
    - n_edges: No. of edges in the model.
    - is_parameterized: Whether it is just the network structure or also has parameters (CPDs) defined.
    - is_discrete: Whether the model has only discrete variables / parameterization.
    - is_continuous: Whether the model has only continuous variables / parameterization.
    - is_hybrid: Whether the model has both discrete and continuous variables / parameterization.

    Parameters
    ----------
    **filter_tags
        Tag/value pairs used to filter the models. They are passed unchanged
        to skbase's ``all_objects`` as ``filter_tags``, which decides how values are matched.

    Returns
    -------
    list
        Sorted list of the names of all matching example models.

    Raises
    ------
    ValueError
        If any keyword argument is not one of the tags listed above.

    Examples
    --------
    >>> from pgmpy.example_models import list_models
    >>> list_models()
    ['bnlearn/alarm', 'bnlearn/arth150', ..... ]
    >>> list_models(is_discrete=True)
    ['bnlearn/alarm', 'bnlearn/asia', 'bnlearn/cancer', ..... ]
    >>> list_models(is_parameterized=False)
    ['dagitty/acid_1996', ...., ]
    """
    valid_tags = set(_BaseExampleModel._tags)
    if invalid_tags := set(filter_tags) - valid_tags:
        raise ValueError(
            f"Unrecognized filter argument(s): {sorted(invalid_tags)}. Valid filter tags are: {sorted(valid_tags)}."
        )

    all_models = all_objects(
        object_types=_BaseExampleModel,
        package_name="pgmpy.example_models",
        return_names=False,
        filter_tags=filter_tags,
    )
    # Look up each class's name tag once. Keep only real string names, so that
    # classes with an unset or non-string default (e.g. a base class) can
    # neither leak into the result nor break ``sorted`` with mixed types.
    names = (model_cls.get_class_tag("name") for model_cls in all_models)
    return sorted(model_name for model_name in names if isinstance(model_name, str))