Source code for pgmpy.metrics._base
import pandas as pd
from skbase.base import BaseObject
from skbase.lookup import all_objects
class _BaseSupervisedMetric(BaseObject):
"""
Base class for all metric classes in pgmpy that require ground truth causal graph.
"""
def evaluate(self, true_causal_graph, est_causal_graph, **kwargs):
"""
Evaluate the metric by comparing the true causal graph with the estimated causal graph.
Parameters
----------
true_causal_graph: Instance of type pgmpy.base
The ground truth causal graph.
est_causal_graph: Instance of type pgmpy.base
The estimated causal graph.
"""
if not isinstance(true_causal_graph, self._tags["supported_graph_types"]):
raise ValueError(
f"The true_causal_graph must be one of the following types: "
f"{self._tags['supported_graph_types']}, "
f"but got {type(true_causal_graph)} instead."
)
if not isinstance(est_causal_graph, self._tags["supported_graph_types"]):
raise ValueError(
f"The est_causal_graph must be one of the following types: "
f"{self._tags['supported_graph_types']}, "
f"but got {type(est_causal_graph)} instead."
)
if not set(true_causal_graph.nodes()) == set(est_causal_graph.nodes()):
raise ValueError("The `true_causal_graph` and `est_causal_graph` must be on the same nodes.")
return self._evaluate(
true_causal_graph=true_causal_graph,
est_causal_graph=est_causal_graph,
**kwargs,
)
def __call__(self, true_causal_graph, est_causal_graph, **kwargs):
return self.evaluate(
true_causal_graph=true_causal_graph,
est_causal_graph=est_causal_graph,
**kwargs,
)
class _BaseUnsupervisedMetric(BaseObject):
"""
Base class for all metric classes in pgmpy that do not require ground truth causal graph.
"""
def evaluate(self, X, causal_graph, **kwargs):
"""
Evaluate the metric by comparing the causal graph with the data.
Parameters
----------
X: pandas.DataFrame
The data used for evaluation.
causal_graph: Instance of type pgmpy.base
The causal graph to be evaluated.
"""
if not isinstance(causal_graph, self._tags["supported_graph_types"]):
raise ValueError(
f"The causal_graph must be one of the following types: "
f"{self._tags['supported_graph_types']}, "
f"but got {type(causal_graph)} instead."
)
if not isinstance(X, pd.DataFrame):
raise ValueError(f"The data must be a pandas.DataFrame instance, but got {type(X)} instead.")
elif len(set(causal_graph.nodes()) - set(X.columns)) > 0:
raise ValueError(
"Missing columns in data. Can't find values for the following variables: "
f" {set(causal_graph.nodes()) - set(X.columns)}"
)
return self._evaluate(X=X, causal_graph=causal_graph, **kwargs)
def __call__(self, X, causal_graph, **kwargs):
return self.evaluate(X=X, causal_graph=causal_graph, **kwargs)
[docs]
def get_metrics(**kwargs):
"""
Get metric classes matching the given tag filters.
Parameters
----------
**kwargs
Keyword arguments specifying tag filters to be passed to
:func:`skbase.lookup.all_objects` via its ``filter_tags`` parameter.
Returns
-------
Type[BaseObject] or list[Type[BaseObject]]
Metric class(es) corresponding to the given tag filters.
Raises
------
ValueError
If no metric class matching the given tag filters is found.
"""
return all_objects(
object_types=[_BaseSupervisedMetric, _BaseUnsupervisedMetric],
package_name="pgmpy.metrics",
return_names=False,
filter_tags=kwargs,
)