from __future__ import annotations
import io
import re
from dataclasses import dataclass
from typing import Any
import numpy as np
import pandas as pd
from skbase.base import BaseObject
from skbase.lookup import all_objects
from pgmpy.base import DAG
from pgmpy.causal_discovery import ExpertKnowledge
from pgmpy.utils.hf_hub import read_hf_file
@dataclass
class Dataset:
    """Container bundling a loaded dataset with its optional metadata.

    Attributes
    ----------
    name : str
        Name of the dataset.
    data : pd.DataFrame
        The dataset's samples (or simulated samples).
    expert_knowledge : ExpertKnowledge | None
        Optional expert knowledge (forbidden/required edges, temporal order).
    ground_truth : DAG | None
        Optional ground-truth DAG.
    tags : dict[str, Any] | None
        Optional tag dictionary describing dataset properties.
    """

    name: str
    data: pd.DataFrame
    expert_knowledge: ExpertKnowledge | None = None
    ground_truth: DAG | None = None
    # None (not a shared mutable dict) is the safe dataclass default; the
    # annotation is widened to `| None` to match that default.
    tags: dict[str, Any] | None = None

    def __str__(self) -> str:
        return (
            f"Dataset(name={self.name}, \n data=DataFrame of size: {self.data.shape}, \n "
            f"expert_knowledge={self.expert_knowledge}, \n ground_truth={self.ground_truth}, \n tags={self.tags})"
        )

    def __repr__(self) -> str:
        return self.__str__()
class _BaseDataset(BaseObject):
    """
    Base class for all datasets in pgmpy.

    Inherits from skbase.base.BaseObject to utilize its tag and lookup
    functionality. Concrete subclasses are expected to provide the class
    attributes referenced here — e.g. ``data_url``, ``expert_knowledge_url``,
    ``ground_truth_url``, ``missing_values_marker``,
    ``categorical_variables``, and ``ordinal_variables`` — which of these are
    required depends on the tags the subclass sets.
    """

    # Default tag values; subclasses override these to describe the dataset.
    _tags = {
        "name": None,
        "n_variables": None,
        "n_samples": None,
        "has_ground_truth": False,
        "has_expert_knowledge": False,
        "has_missing_data": False,
        "has_index_col": False,
        "is_simulated": False,
        "is_interventional": False,
        "is_discrete": False,
        "is_continuous": False,
        "is_mixed": False,
        "is_ordinal": False,
    }

    # Location of the dataset files on the Hugging Face Hub.
    base_url = ""
    repo_id = "pgmpy/example_datasets"
    repo_type = "dataset"
    revision = "main"

    @staticmethod
    def _parse_expert_knowledge(raw_expert_knowledge: bytes) -> ExpertKnowledge:
        """
        Helper method to parse expert knowledge from raw bytes.

        Parses a TETRAD-style knowledge file with sections ``addtemporal``
        (tier lines: a tier number optionally followed by variable names),
        ``forbiddirect`` and ``requiredirect`` (edge lines: ``source target``).
        """
        text = raw_expert_knowledge.decode("utf-8-sig", errors="ignore")
        temporal: list[list[str]] = []
        forbids: list[tuple[str, str]] = []
        requires: list[tuple[str, str]] = []
        section = None
        for raw_line in text.splitlines():
            stripped = raw_line.strip()
            if not stripped:
                continue
            lower = stripped.lower()
            # Section headers
            if lower == "/knowledge":
                section = None
                continue
            if lower == "addtemporal":
                section = "addtemporal"
                continue
            if lower == "forbiddirect":
                section = "forbiddirect"
                continue
            if lower in ("requiredirect", "requireddirect"):
                section = "requiredirect"
                continue
            # Content, depending on current section
            if section == "addtemporal":
                # A line that is just an integer is a tier with no variables.
                if stripped.isdigit():
                    temporal.append([])
                else:
                    tokens = stripped.split()
                    # Skip the first token; it's the tier number
                    temporal.append(tokens[1:])
            elif section == "forbiddirect":
                tokens = stripped.split()
                forbids.append((tokens[0], tokens[1]))
            elif section == "requiredirect":
                tokens = stripped.split()
                requires.append((tokens[0], tokens[1]))
        return ExpertKnowledge(forbidden_edges=forbids, required_edges=requires, temporal_order=temporal)

    @classmethod
    def _get_raw_data(cls, filename: str) -> bytes:
        """
        Fetches a dataset file from the Hugging Face Hub cache.

        Parameters
        ----------
        filename : str
            Name of the file inside ``cls.base_url`` on the Hub repository.
        """
        # BUGFIX: use the `filename` argument in the requested path; it was
        # previously ignored and a literal placeholder path was requested.
        return read_hf_file(
            repo_id=cls.repo_id,
            filename=f"{cls.base_url}/{filename}",
            repo_type=cls.repo_type,
            revision=cls.revision,
        )

    @classmethod
    def load_dataframe(cls) -> pd.DataFrame:
        """
        Fetches/reads from cache the data associated with the dataset.

        Applies the missing-value marker, drops the index column, and casts
        categorical/ordinal columns according to the dataset's tags and class
        attributes.
        """
        raw_data = cls._get_raw_data(cls.data_url)
        # Subclasses may override the field separator via a `sep` attribute.
        df = pd.read_csv(io.BytesIO(raw_data), sep=getattr(cls, "sep", "\t"))
        if cls.get_class_tag("has_missing_data"):
            df.replace(cls.missing_values_marker, pd.NA, inplace=True)
        if cls.get_class_tag("has_index_col"):
            # The first column is a redundant row index; drop it.
            df.drop(df.columns[0], axis=1, inplace=True)
        for col in cls.categorical_variables:
            df[col] = df[col].astype("category")
        # ordinal_variables maps column name -> ordered category list.
        for col, order in cls.ordinal_variables.items():
            cat_type = pd.CategoricalDtype(categories=order, ordered=True)
            df[col] = df[col].astype(cat_type)
        return df

    @classmethod
    def load_expert_knowledge(cls) -> ExpertKnowledge | None:
        """Fetches/reads from cache the expert knowledge associated with the dataset.

        Returns None when the dataset declares no expert knowledge.
        """
        if not cls.get_class_tag("has_expert_knowledge"):
            return None
        raw_data = cls._get_raw_data(cls.expert_knowledge_url)
        expert_knowledge = cls._parse_expert_knowledge(raw_data)
        return expert_knowledge

    @classmethod
    def load_ground_truth(cls) -> DAG | None:
        """Fetches/reads from cache the ground truth DAG associated with the dataset.

        Returns None when the dataset declares no ground truth.
        """
        if not cls.get_class_tag("has_ground_truth"):
            return None
        raw_data = cls._get_raw_data(cls.ground_truth_url).decode("utf-8-sig", errors="ignore")
        return DAG.from_dagitty(raw_data)
class _CovarianceMixin:
    """
    Mixin for datasets that are distributed as a covariance matrix rather
    than raw samples. It overrides ``load_dataframe`` to simulate data from
    the covariance matrix instead of reading a static data file as
    ``_BaseDataset`` does.
    """

    @classmethod
    def _load_covariance_matrix(cls) -> pd.DataFrame:
        """
        Fetch the raw file and assemble the covariance matrix DataFrame.

        Expected layout: one header line, a line of variable names, then the
        lower triangle of the matrix (row ``i`` carrying ``i + 1`` values).
        The triangle is mirrored to produce the full symmetric matrix.
        """
        text = cls._get_raw_data(cls.data_url).decode("utf-8-sig", errors="ignore")
        rows = text.strip().splitlines()
        # Delimiters are not uniform across datasets, so split the name line
        # on arbitrary whitespace runs.
        variables = rows[1].split()
        size = len(variables)
        cov = np.zeros((size, size), dtype=float)
        for idx, row in enumerate(rows[2 : 2 + size]):
            lower = np.fromstring(row, sep="\t", dtype=float)
            cov[idx, : idx + 1] = lower
            cov[: idx + 1, idx] = lower
        return pd.DataFrame(cov, columns=variables, index=variables)

    @classmethod
    def load_dataframe(cls) -> pd.DataFrame:
        """Simulate a DataFrame of zero-mean samples from the covariance matrix.

        This overrides ``_BaseDataset.load_dataframe``; for the override to
        take effect, this mixin must come first among the parent classes.
        """
        covariance = cls._load_covariance_matrix()
        zero_mean = [0] * cls.get_class_tag("n_variables")
        samples = np.random.multivariate_normal(
            zero_mean, covariance.values, size=cls.get_class_tag("n_samples")
        )
        return pd.DataFrame(samples, columns=covariance.columns)
class _TubingenBenchmarkMixin:
    """
    Mixin for the Tubingen cause-effect benchmark, which ships as many
    independent pair files rather than one table.
    URL: https://webdav.tuebingen.mpg.de/cause-effect/
    """

    @classmethod
    def load_dataframe(cls, pair_id: int) -> pd.DataFrame:
        """Load the two-column (x, y) sample file for the given pair id."""
        contents = cls._get_raw_data(f"pair{pair_id:04}.txt")
        buffer = io.BytesIO(contents)
        return pd.read_csv(buffer, sep=r"\s+", header=None, names=["x", "y"])

    @classmethod
    def load_ground_truth(cls, pair_id: int) -> DAG:
        """Load the ground-truth graph file for the given pair id."""
        contents = cls._get_raw_data(f"pair{pair_id:04}_graph.txt")
        decoded = contents.decode("utf-8-sig", errors="ignore")
        return DAG.from_dagitty(decoded)
def load_dataset(name: str) -> Dataset:
    """
    Load a dataset by name.

    Parameters
    ----------
    name : str
        Name of the dataset to load. Tubingen benchmark pairs use the form
        ``"tubingen/<pair_id>"`` with ``pair_id`` between 1 and 108.

    Returns
    -------
    Dataset
        The dataset with its data, optional expert knowledge, optional
        ground-truth DAG, and tags.

    Raises
    ------
    ValueError
        If the name is unknown, or a Tubingen name/pair id is malformed.

    Examples
    --------
    >>> from pgmpy.datasets import load_dataset
    >>> dataset = load_dataset("sachs_mixed")
    >>> df = dataset.data
    >>> ground_truth = dataset.ground_truth
    """
    all_datasets = all_objects(object_types=_BaseDataset, package_name="pgmpy.datasets", return_names=False)
    if name.startswith("tubingen"):
        # Guard clause: reject anything that is not exactly "tubingen/<int>".
        name_parts = name.split("/")
        if len(name_parts) != 2 or not name_parts[1].isdigit():
            raise ValueError(f"Invalid dataset name format: '{name}'. For Tubingen datasets, use 'tubingen/<pair_id>'.")
        pair_id = int(name_parts[1])
        if not (1 <= pair_id <= 108):
            raise ValueError(f"Tubingen pair ID must be between 1 and 108. Got {pair_id}.")
        target_cls = next(
            (cls for cls in all_datasets if cls.get_class_tag("name") == "tubingen"),
            None,
        )
        # Fail with a clear message instead of an AttributeError on None when
        # the tubingen dataset class is not registered.
        if target_cls is None:
            raise ValueError("Dataset with name 'tubingen' not found. Please use list_datasets() to see available datasets.")
        df = target_cls.load_dataframe(pair_id)
        gt = target_cls.load_ground_truth(pair_id)
        # Sample count and missingness vary per pair, so they are filled in
        # from the loaded data rather than taken from the class tags.
        tags = target_cls.get_class_tags()
        tags["n_samples"] = df.shape[0]
        tags["has_missing_data"] = bool(df.isnull().any().any())
        return Dataset(
            name=name,
            data=df,
            expert_knowledge=None,
            ground_truth=gt,
            tags=tags,
        )
    # Non-tubingen datasets: look up the class whose "name" tag matches.
    target_cls = next((cls for cls in all_datasets if cls.get_class_tag("name") == name), None)
    if target_cls is None:
        raise ValueError(f"Dataset with name '{name}' not found. Please use list_datasets() to see available datasets.")
    return Dataset(
        name=name,
        data=target_cls.load_dataframe(),
        expert_knowledge=target_cls.load_expert_knowledge(),
        ground_truth=target_cls.load_ground_truth(),
        tags=target_cls.get_class_tags(),
    )
def list_datasets(**filter_tags) -> list[str]:
    """
    Returns a list of all available datasets, optionally filtered by tags.

    Parameters
    ----------
    **filter_tags : optional arguments
        If specified, returns only datasets matching the provided tag filters. Any dataset tag can be used as a filter.
        Available tags:
        - n_variables
        - n_samples
        - has_ground_truth
        - has_expert_knowledge
        - has_missing_data
        - is_simulated
        - is_interventional
        - is_discrete
        - is_continuous
        - is_mixed
        - is_ordinal

    Returns
    -------
    list of str
        A sorted list of available dataset names.

    Raises
    ------
    ValueError
        If a filter keyword is not a recognized dataset tag.

    Examples
    --------
    >>> from pgmpy.datasets import list_datasets
    >>> list_datasets()
    ['abalone_continuous', 'abalone_mixed', ..., 'sachs_continuous', ...]
    >>> list_datasets(is_discrete=True, has_ground_truth=True)
    ['sachs_discrete']
    """
    # Validate filter keywords up-front so typos raise a clear error instead
    # of silently matching nothing.
    valid_tags = set(_BaseDataset._tags.keys())
    if invalid_tags := set(filter_tags.keys()) - valid_tags:
        raise ValueError(
            f"Unrecognized filter argument(s): {sorted(invalid_tags)}. Valid filter tags are: {sorted(valid_tags)}."
        )
    all_datasets = all_objects(
        object_types=_BaseDataset,
        package_name="pgmpy.datasets",
        return_names=False,
        filter_tags=filter_tags,
    )
    # Skip abstract/base classes that carry no "name" tag.
    dataset_names = [cls.get_class_tag("name") for cls in all_datasets if cls.get_class_tag("name") is not None]
    return sorted(dataset_names)