Tree Search
- class pgmpy.estimators.TreeSearch(data, root_node=None, n_jobs=-1, **kwargs)
Search class for learning tree-structured graphs. The supported algorithms are Chow-Liu and Tree-Augmented Naive Bayes (TAN).
Chow-Liu constructs the maximum-weight spanning tree over the variables, using pairwise mutual information scores as edge weights.
TAN extends the Naive Bayes classifier by allowing a tree structure over the feature (independent) variables, so that interactions between features can be modeled.
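The Chow-Liu procedure described above can be sketched without pgmpy, assuming discrete data in a pandas DataFrame. The helpers `mutual_info` and `chow_liu` are hypothetical names used only for this illustration; they are not part of pgmpy's API and this is not pgmpy's implementation. The sketch computes empirical mutual information (in nats) for every pair of columns, takes the maximum-weight spanning tree with networkx's `maximum_spanning_tree`, and orients the edges away from a caller-chosen root with `bfs_tree`.

```python
import itertools

import networkx as nx
import numpy as np
import pandas as pd


def mutual_info(x: pd.Series, y: pd.Series) -> float:
    """Empirical mutual information (in nats) between two discrete columns."""
    joint = pd.crosstab(x, y, normalize=True).to_numpy()  # P(x, y)
    px = joint.sum(axis=1, keepdims=True)                 # P(x), shape (|X|, 1)
    py = joint.sum(axis=0, keepdims=True)                 # P(y), shape (1, |Y|)
    nz = joint > 0                                        # 0 * log 0 is treated as 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / (px @ py)[nz])))


def chow_liu(data: pd.DataFrame, root) -> nx.DiGraph:
    """Chow-Liu tree over the columns of `data`, with edges directed away from `root`."""
    g = nx.Graph()
    g.add_nodes_from(data.columns)
    # Complete graph: every pair of variables, weighted by mutual information.
    for u, v in itertools.combinations(data.columns, 2):
        g.add_edge(u, v, weight=mutual_info(data[u], data[v]))
    tree = nx.maximum_spanning_tree(g, weight="weight")
    # Breadth-first search from the root orients each tree edge parent -> child.
    return nx.bfs_tree(tree, root)


# Usage: B is a noisy copy of A, while C is independent noise.
rng = np.random.default_rng(0)
a = rng.integers(0, 2, size=1000)
df = pd.DataFrame({
    "A": a,
    "B": a ^ (rng.random(1000) < 0.1),  # flip roughly 10% of A's values
    "C": rng.integers(0, 2, size=1000),
})
dag = chow_liu(df, root="A")
print(sorted(dag.edges()))
```

Because A and B share far more information than either shares with C, the A-B edge is always kept; C attaches to whichever of A or B happens to score slightly higher on this sample.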
- Parameters:
data (pandas.DataFrame object) – A DataFrame in which each column represents one variable.
root_node (str, int, or any hashable python object, default is None.) – The root node of the tree structure. If None, the root node is picked automatically as the node with the highest sum of edge weights.
n_jobs (int (default: -1)) – Number of jobs to run in parallel. -1 means use all processors.
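The automatic root choice for root_node=None can be illustrated as follows. This is a minimal sketch assuming the edge weights are held in a symmetric pandas DataFrame of pairwise scores with a zero diagonal, and that the sum runs over all pairwise weights rather than only the edges of the final tree; `pick_root` and the weight values are hypothetical.

```python
import pandas as pd


def pick_root(weights: pd.DataFrame):
    """Return the node whose pairwise edge weights have the largest total."""
    # Row sums of a symmetric weight matrix give each node's total edge weight.
    return weights.sum(axis=1).idxmax()


# Hypothetical pairwise weights (e.g. mutual information) for nodes A, B, C.
weights = pd.DataFrame(
    [[0.00, 0.30, 0.05],
     [0.30, 0.00, 0.20],
     [0.05, 0.20, 0.00]],
    index=["A", "B", "C"],
    columns=["A", "B", "C"],
)
root = pick_root(weights)  # "B": its weights sum to 0.50, vs 0.35 (A) and 0.25 (C)
```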
References
- [1] Chow, C. K.; Liu, C. N. (1968). “Approximating discrete probability distributions with dependence trees”. IEEE Transactions on Information Theory, IT-14 (3): 462–467.
- [2] Friedman, N.; Geiger, D.; Goldszmidt, M. (1997). “Bayesian network classifiers”. Machine Learning, 29: 131–163.
- estimate(estimator_type='chow-liu', class_node=None, edge_weights_fn='mutual_info', show_progress=True)
Estimate the DAG structure that best fits the given data set, without parameterization.
- Parameters:
estimator_type (str (chow-liu | tan), default: chow-liu) – The algorithm to use for estimating the DAG.
class_node (string, int or any hashable python object (optional)) – Required only if estimator_type = ‘tan’. The estimated DAG then contains an edge from class_node to each of the feature variables.
edge_weights_fn (str or function (default: mutual_info)) – Method used to compute edge weights. By default, the mutual information score is used.
show_progress (boolean (default: True)) – If True, shows a progress bar while the algorithm runs.
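For estimator_type = ‘tan’, the formulation in [2] weights each pair of features by their mutual information conditioned on the class variable, builds a maximum-weight spanning tree over the features only, and then adds an edge from the class node to every feature. The sketch below follows that textbook formulation, assuming discrete data and a caller-chosen feature as the tree root. The helper names (`mutual_info`, `conditional_mutual_info`, `tan`) are hypothetical, and pgmpy's own edge weighting, controlled by edge_weights_fn, may differ from the conditional mutual information used here.

```python
import itertools

import networkx as nx
import numpy as np
import pandas as pd


def mutual_info(x: pd.Series, y: pd.Series) -> float:
    """Empirical mutual information (in nats) between two discrete columns."""
    joint = pd.crosstab(x, y, normalize=True).to_numpy()
    outer = joint.sum(axis=1, keepdims=True) @ joint.sum(axis=0, keepdims=True)
    nz = joint > 0  # 0 * log 0 is treated as 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / outer[nz])))


def conditional_mutual_info(data: pd.DataFrame, x, y, class_node) -> float:
    """I(x; y | class_node) = sum over classes c of P(c) * I(x; y | class_node = c)."""
    return sum(
        len(group) / len(data) * mutual_info(group[x], group[y])
        for _, group in data.groupby(class_node)
    )


def tan(data: pd.DataFrame, class_node, root) -> nx.DiGraph:
    """TAN structure: a tree over the features plus class_node -> feature edges."""
    features = [col for col in data.columns if col != class_node]
    if root not in features:
        raise ValueError("root must be a feature column, not the class node")
    g = nx.Graph()
    g.add_nodes_from(features)
    # Complete graph over the features, weighted by conditional mutual information.
    for u, v in itertools.combinations(features, 2):
        g.add_edge(u, v, weight=conditional_mutual_info(data, u, v, class_node))
    # Maximum-weight spanning tree, oriented away from the root.
    dag = nx.bfs_tree(nx.maximum_spanning_tree(g, weight="weight"), root)
    # Naive Bayes part: the class node is a parent of every feature.
    dag.add_edges_from((class_node, f) for f in features)
    return dag


# Usage: X1 depends on the class C, X2 is a noisy copy of X1, X3 is noise.
rng = np.random.default_rng(1)
n = 2000
c = rng.integers(0, 2, size=n)
x1 = c ^ (rng.random(n) < 0.2)
x2 = x1 ^ (rng.random(n) < 0.1)
df = pd.DataFrame({"C": c, "X1": x1, "X2": x2, "X3": rng.integers(0, 2, size=n)})
dag = tan(df, class_node="C", root="X1")
print(sorted(dag.edges()))
```

Conditioning on the class is the design choice that distinguishes TAN from running Chow-Liu on the features alone: two features that are correlated only through the class score close to zero and are not linked by a tree edge.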
- Returns:
Estimated Model – The estimated model structure.
- Return type:
Examples
>>> import numpy as np
>>> import pandas as pd
>>> import networkx as nx
>>> import matplotlib.pyplot as plt
>>> from pgmpy.estimators import TreeSearch
>>> values = pd.DataFrame(np.random.randint(low=0, high=2, size=(1000, 5)),
...                       columns=['A', 'B', 'C', 'D', 'E'])
>>> # Chow-Liu tree with an explicit root node
>>> est = TreeSearch(values, root_node='B')
>>> model = est.estimate(estimator_type='chow-liu')
>>> nx.draw_circular(model, with_labels=True, arrowsize=20, arrowstyle='fancy',
...                  alpha=0.3)
>>> plt.show()
>>> # Chow-Liu tree with an automatically picked root node
>>> est = TreeSearch(values)
>>> model = est.estimate(estimator_type='chow-liu')
>>> nx.draw_circular(model, with_labels=True, arrowsize=20, arrowstyle='fancy',
...                  alpha=0.3)
>>> plt.show()
>>> # TAN with class node 'A' and an explicit root node
>>> est = TreeSearch(values, root_node='B')
>>> model = est.estimate(estimator_type='tan', class_node='A')
>>> nx.draw_circular(model, with_labels=True, arrowsize=20, arrowstyle='fancy',
...                  alpha=0.3)
>>> plt.show()
>>> # TAN with an automatically picked root node; class_node is still required
>>> est = TreeSearch(values)
>>> model = est.estimate(estimator_type='tan', class_node='A')
>>> nx.draw_circular(model, with_labels=True, arrowsize=20, arrowstyle='fancy',
...                  alpha=0.3)
>>> plt.show()