Tree Search
- class pgmpy.estimators.TreeSearch(data, root_node=None, n_jobs=-1, **kwargs)
- static chow_liu(data, root_node, edge_weights_fn='mutual_info', n_jobs=-1, show_progress=True)
Chow-Liu algorithm for estimating a tree structure from the given data. Refer to pgmpy.estimators.TreeSearch for more details.
- Parameters:
data (pandas.DataFrame object) – DataFrame in which each column represents one variable.
root_node (str, int, or any hashable Python object) – The root node of the tree structure.
n_jobs (int (default: -1)) – Number of jobs to run in parallel. -1 means use all processors.
edge_weights_fn (str or function (default: 'mutual_info')) –
- Method to use for computing edge weights. Options are:
'mutual_info': Mutual Information Score.
'adjusted_mutual_info': Adjusted Mutual Information Score.
'normalized_mutual_info': Normalized Mutual Information Score.
function(array[n_samples,], array[n_samples,]): Custom function that takes two arrays of shape (n_samples,) and returns the edge weight.
show_progress (boolean (default: True)) – If True, shows a progress bar while the algorithm runs.
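A custom edge_weights_fn receives two columns as arrays of length n_samples and returns a single number. As a minimal sketch, assuming discrete (categorical) columns, the hypothetical helper below computes the empirical mutual information in nats. It only illustrates the expected call shape; it is not how pgmpy computes 'mutual_info' internally.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def mutual_info(x: Sequence[Hashable], y: Sequence[Hashable]) -> float:
    """Empirical mutual information (in nats) between two discrete samples.

    Hypothetical helper: two equal-length arrays in, one float weight out.
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    n = len(x)
    joint = Counter(zip(x, y))  # counts of each observed (x, y) pair
    px = Counter(x)             # marginal counts of x
    py = Counter(y)             # marginal counts of y
    mi = 0.0
    for (a, b), n_ab in joint.items():
        # p(a,b) * log(p(a,b) / (p(a) * p(b))), rewritten with raw counts
        mi += (n_ab / n) * math.log(n_ab * n / (px[a] * py[b]))
    return mi


print(mutual_info([0, 0, 1, 1], [0, 0, 1, 1]))  # log(2), about 0.693
print(mutual_info([0, 0, 1, 1], [0, 1, 0, 1]))  # 0.0 for independent columns
```

Under the documented signature, such a function could be passed as edge_weights_fn=mutual_info. If scikit-learn is available, sklearn.metrics.mutual_info_score computes the same quantity for label arrays.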
- Returns:
Estimated Model – The estimated model structure.
- Return type:
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from pgmpy.estimators import TreeSearch
>>> values = pd.DataFrame(np.random.randint(low=0, high=2, size=(1000, 5)),
...                       columns=['A', 'B', 'C', 'D', 'E'])
>>> est = TreeSearch(values, root_node='B')
>>> model = est.estimate(estimator_type='chow-liu')
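The Chow-Liu procedure can be sketched in three steps: weight every pair of variables (for example by mutual information), keep a maximum-weight spanning tree, then direct each edge away from root_node. The sketch below assumes the pairwise weights have already been computed and stored in a dict; the name chow_liu_tree and its inputs are hypothetical, and Kruskal's algorithm is one possible spanning-tree method, not necessarily the one pgmpy uses.

```python
from collections import deque
from typing import Dict, Hashable, Iterable, List, Tuple

Edge = Tuple[Hashable, Hashable]


def chow_liu_tree(nodes: Iterable[Hashable], weights: Dict[Edge, float],
                  root: Hashable) -> List[Edge]:
    """Directed edges of a maximum-weight spanning tree, oriented away from root.

    `weights` maps an unordered pair (u, v) to its edge weight, e.g. the mutual
    information between columns u and v. Hypothetical helper, not pgmpy API.
    """
    nodes = list(nodes)
    parent = {n: n for n in nodes}  # union-find forest used by Kruskal

    def find(n: Hashable) -> Hashable:
        # Walk up to the set representative, halving the path along the way.
        while parent[n] != n:
            parent[n] = parent[parent[n]]
            n = parent[n]
        return n

    adjacency: Dict[Hashable, List[Hashable]] = {n: [] for n in nodes}
    # Kruskal on descending weights: keep an edge only if it joins two components.
    for (u, v), _ in sorted(weights.items(), key=lambda item: item[1], reverse=True):
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            adjacency[u].append(v)
            adjacency[v].append(u)

    # Breadth-first search from the root turns each tree edge into parent -> child.
    edges: List[Edge] = []
    seen = {root}
    queue = deque([root])
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if v not in seen:
                seen.add(v)
                edges.append((u, v))
                queue.append(v)
    return edges


weights = {("A", "B"): 0.9, ("B", "C"): 0.8, ("A", "C"): 0.7,
           ("C", "D"): 0.5, ("A", "D"): 0.1, ("B", "D"): 0.2}
print(chow_liu_tree("ABCD", weights, root="B"))
# [('B', 'A'), ('B', 'C'), ('C', 'D')]
```

The edge A-C (0.7) is skipped because A and C are already connected through B, so the tree keeps the heaviest acyclic set of edges.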
- estimate(estimator_type='chow-liu', class_node=None, edge_weights_fn='mutual_info', show_progress=True)
Estimate the DAG structure that best fits the given data set. Only the structure is learned; no parameters (conditional probability distributions) are estimated.
- Parameters:
estimator_type (str ('chow-liu' | 'tan') (default: 'chow-liu')) – The algorithm to use for estimating the DAG.
class_node (str, int, or any hashable Python object (optional)) – The class node, used only when estimator_type='tan'. If None, the class node is automatically picked as the node with the second-highest sum of edge weights.
edge_weights_fn (str or function (default: 'mutual_info')) – Method to use for computing edge weights (see chow_liu for the available options). By default, the Mutual Information score is used.
show_progress (boolean (default: True)) – If True, shows a progress bar while the algorithm runs.
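The class_node=None rule above amounts to ranking nodes by their total edge weight and taking the runner-up. A minimal sketch, assuming the pairwise weights are held in a dict keyed by unordered node pairs; pick_class_node is a hypothetical name, and breaking ties by first appearance is an assumption of this sketch rather than documented behaviour.

```python
from collections import defaultdict
from typing import Dict, Hashable, Tuple


def pick_class_node(weights: Dict[Tuple[Hashable, Hashable], float]) -> Hashable:
    """Return the node with the second-highest sum of edge weights.

    Hypothetical helper; ties keep first-appearance order (stable sort).
    """
    totals: Dict[Hashable, float] = defaultdict(float)
    for (u, v), w in weights.items():
        # Each undirected edge contributes its weight to both endpoints.
        totals[u] += w
        totals[v] += w
    if len(totals) < 2:
        raise ValueError("need at least two nodes")
    ranked = sorted(totals, key=totals.__getitem__, reverse=True)
    return ranked[1]


weights = {("A", "B"): 0.9, ("B", "C"): 0.8, ("A", "C"): 0.7,
           ("C", "D"): 0.5, ("A", "D"): 0.1, ("B", "D"): 0.2}
print(pick_class_node(weights))  # B
```

Here C has the largest total (2.0) and B the second largest (1.9), so B is returned.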
- Returns:
Estimated Model – The estimated model structure.
- Return type:
Examples
>>> import numpy as np
>>> import pandas as pd
>>> import networkx as nx
>>> import matplotlib.pyplot as plt
>>> from pgmpy.estimators import TreeSearch
>>> values = pd.DataFrame(np.random.randint(low=0, high=2, size=(1000, 5)),
...                       columns=['A', 'B', 'C', 'D', 'E'])
>>> # Chow-Liu tree rooted at 'B'
>>> est = TreeSearch(values, root_node='B')
>>> model = est.estimate(estimator_type='chow-liu')
>>> nx.draw_circular(model, with_labels=True, arrowsize=20, arrowstyle='fancy',
...                  alpha=0.3)
>>> plt.show()
>>> # Chow-Liu tree with root_node left as None
>>> est = TreeSearch(values)
>>> model = est.estimate(estimator_type='chow-liu')
>>> nx.draw_circular(model, with_labels=True, arrowsize=20, arrowstyle='fancy',
...                  alpha=0.3)
>>> plt.show()
>>> # TAN with root 'B' and class node 'A'
>>> est = TreeSearch(values, root_node='B')
>>> model = est.estimate(estimator_type='tan', class_node='A')
>>> nx.draw_circular(model, with_labels=True, arrowsize=20, arrowstyle='fancy',
...                  alpha=0.3)
>>> plt.show()
>>> # TAN with the class node picked automatically
>>> est = TreeSearch(values)
>>> model = est.estimate(estimator_type='tan')
>>> nx.draw_circular(model, with_labels=True, arrowsize=20, arrowstyle='fancy',
...                  alpha=0.3)
>>> plt.show()
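A TAN (tree-augmented naive Bayes) structure, as requested with estimator_type='tan', makes the class node a parent of every other variable and adds a directed tree among those remaining variables. The sketch below only assembles that shape from a feature tree that is assumed to be given (for example, Chow-Liu output over the non-class columns). In the classic TAN formulation, feature pairs are weighted by mutual information conditioned on the class; this sketch makes no assumption about which weighting pgmpy uses, and tan_edges is a hypothetical name.

```python
from typing import Hashable, Iterable, List, Tuple

Edge = Tuple[Hashable, Hashable]


def tan_edges(class_node: Hashable, features: Iterable[Hashable],
              feature_tree: Iterable[Edge]) -> List[Edge]:
    """Edges of a TAN structure: class -> every feature, plus a feature tree.

    Hypothetical helper, not pgmpy API. `feature_tree` is assumed to be a
    directed tree over `features` that does not involve class_node.
    """
    features = list(features)
    tree = list(feature_tree)
    if class_node in features:
        raise ValueError("class_node must not be one of the features")

    # Validate the tree part: only feature nodes, at most one feature parent each.
    feature_parents = {f: 0 for f in features}
    for u, v in tree:
        if u not in feature_parents or v not in feature_parents:
            raise ValueError("feature_tree may only contain feature nodes")
        feature_parents[v] += 1
        if feature_parents[v] > 1:
            raise ValueError(f"feature {v!r} has more than one feature parent")

    # Naive Bayes part (class -> every feature) followed by the feature tree.
    return [(class_node, f) for f in features] + tree


print(tan_edges("A", ["B", "C", "D", "E"], [("B", "C"), ("B", "D"), ("D", "E")]))
```

Every feature ends up with the class node as a parent plus at most one feature parent; the root of the feature tree (here B) has only the class node as its parent.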