BaseEstimator
- class pgmpy.estimators.BaseEstimator(data=None, state_names=None)
Bases: object
Base class for estimators in pgmpy; ParameterEstimator, StructureEstimator and StructureScore derive from this class.
- Parameters:
- data: pandas DataFrame object
DataFrame in which each column represents one variable. If some values in the data are missing, the corresponding cells should be set to numpy.nan. Note that pandas converts integer columns containing numpy.nan to dtype float.
- state_names: dict (optional)
A dict indicating, for each variable, the discrete set of states (or values) that the variable can take. If unspecified, the observed values in the data set are taken to be the only possible states.
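The missing-value note and the default for state_names can both be seen with plain pandas. The sketch below assumes nothing about pgmpy internals: infer_state_names is a hypothetical helper written only to illustrate the documented rule that, without state_names, the observed values are taken to be the only possible states. Sorting the inferred states is also an assumption of this sketch.

```python
import numpy as np
import pandas as pd

# Column A holds integers, but one cell is missing.
data = pd.DataFrame({"A": [0, 1, np.nan], "B": ["b1", "b2", "b1"]})

# numpy.nan is a float, so pandas stores the whole column as float64.
print(data["A"].dtype)  # float64


def infer_state_names(df, state_names=None):
    """Hypothetical helper (not part of pgmpy): observed, non-missing values
    per column, overridden by any user-supplied state names."""
    # Missing cells are dropped, so NaN is never treated as a state.
    # Sorting is an assumption; each column must hold comparable values.
    inferred = {col: sorted(set(df[col].dropna().tolist())) for col in df.columns}
    inferred.update(state_names or {})
    return inferred


print(infer_state_names(data))
# {'A': [0.0, 1.0], 'B': ['b1', 'b2']}
print(infer_state_names(data, state_names={"B": ["b1", "b2", "b3"]}))
# {'A': [0.0, 1.0], 'B': ['b1', 'b2', 'b3']}
```

Supplying state_names is the only way to declare a state such as b3 that never occurs in the data.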
- state_counts(variable, parents=[], weighted=False, reindex=True)
Returns counts of how often each state of ‘variable’ occurred in the data. If a list of parents is provided, counting is done conditionally for each state configuration of the parents.
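Conditional counting amounts to building a contingency table. As a rough plain-pandas analogy (pd.crosstab is a standard pandas function, not necessarily what pgmpy calls internally), the data from the Examples section gives one row per state of C and one column per observed (A, B) configuration. The unobserved configuration (a2, b2) is absent here, which corresponds to the reindex=False behavior described below.

```python
import pandas as pd

data = pd.DataFrame(
    {
        "A": ["a1", "a1", "a2"],
        "B": ["b1", "b2", "b1"],
        "C": ["c1", "c1", "c2"],
    }
)

# Unconditional counts for C, ordered by state name.
print(data["C"].value_counts().sort_index().tolist())  # [2, 1]

# Conditional counts: rows are states of C, columns are observed (A, B) pairs.
counts = pd.crosstab(data["C"], [data["A"], data["B"]])
print(counts)
```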
- Parameters:
- variable: string
Name of the variable for which the state count is to be done.
- parents: list
Optional list of parents of the variable, if conditional counting is desired. The order of parents in the list determines the order of the column levels in the returned DataFrame.
- weighted: bool
If True, data must have a _weight column specifying the weight of the datapoint (row). If False, each datapoint has a weight of 1.
- reindex: bool
If True, returns a DataFrame with all possible combinations of parent states as the columns. If False, drops the state combinations that do not occur in the data.
- Returns:
- state_counts: pandas.DataFrame
Table with state counts for ‘variable’
Examples
>>> import pandas as pd
>>> from pgmpy.estimators import BaseEstimator
>>> data = pd.DataFrame(
...     data={
...         "A": ["a1", "a1", "a2"],
...         "B": ["b1", "b2", "b1"],
...         "C": ["c1", "c1", "c2"],
...     }
... )
>>> estimator = BaseEstimator(data)
>>> estimator.state_counts(variable="A").values
array([[2],
       [1]])
>>> estimator.state_counts(variable="C", parents=["A", "B"]).values
array([[1., 1., 0., 0.],
       [0., 0., 1., 0.]])
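To make the weighted and reindex options concrete, here is a minimal pandas sketch of the counting semantics described above. It is not pgmpy's implementation: sketch_state_counts is a hypothetical function, and it uses the observed parent states where pgmpy would consult state_names.

```python
import pandas as pd


def sketch_state_counts(data, variable, parents=(), weighted=False, reindex=True):
    """Hypothetical sketch of the documented state_counts semantics (not pgmpy code)."""
    parents = list(parents)
    # Each row contributes its _weight when weighted, otherwise 1.
    if weighted:
        weights = data["_weight"].astype(float)
    else:
        weights = pd.Series(1.0, index=data.index)

    # Rows with NaN in any grouped column are dropped by groupby (dropna=True).
    keys = [data[variable]] + [data[p] for p in parents]
    table = weights.groupby(keys).sum()
    if not parents:
        return table.to_frame(name=variable)

    # Variable states become rows; observed parent configurations become columns.
    table = table.unstack(parents, fill_value=0)
    if reindex:
        # Assumption: observed parent states stand in for state_names.
        states = [sorted(data[p].dropna().unique()) for p in parents]
        if len(parents) == 1:
            columns = pd.Index(states[0], name=parents[0])
        else:
            columns = pd.MultiIndex.from_product(states, names=parents)
        table = table.reindex(columns=columns, fill_value=0)
    return table


data = pd.DataFrame(
    {
        "A": ["a1", "a1", "a2"],
        "B": ["b1", "b2", "b1"],
        "C": ["c1", "c1", "c2"],
        "_weight": [0.5, 2.0, 1.0],
    }
)
print(sketch_state_counts(data, "C", ["A", "B"]).values)
print(sketch_state_counts(data, "C", ["A", "B"], weighted=True).values)
print(sketch_state_counts(data, "C", ["A", "B"], reindex=False).values)
```

With the unweighted default, the first call yields the same values as the state_counts example above; weighted=True sums the _weight entries instead of counting rows, and reindex=False drops the unobserved (a2, b2) column.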