from collections import namedtuple
from itertools import product
import numpy as np
import pandas as pd
import torch
from pgmpy import config
from pgmpy.extern import tabulate
from pgmpy.factors.base import BaseFactor
from pgmpy.global_vars import logger
from pgmpy.utils import StateNameMixin, compat_fns
# Lightweight record pairing a variable with one of its states; the two fields
# are accessible as `.var` and `.state`.
State = namedtuple("State", ["var", "state"])
[docs]
class DiscreteFactor(BaseFactor, StateNameMixin):
    """
    Initialize a `DiscreteFactor` class.

    A discrete factor stores one value for every joint assignment of its
    variables. For a factor over three binary variables `x1`, `x2` and `x3`
    the mapping from variable assignments to the index in the flat `values`
    argument is:

    +-----+-----+-----+-------------------+
    | x1  | x2  | x3  |    phi(x1, x2, x3)|
    +-----+-----+-----+-------------------+
    | x1_0| x2_0| x3_0|     phi.value(0)  |
    +-----+-----+-----+-------------------+
    | x1_0| x2_0| x3_1|     phi.value(1)  |
    +-----+-----+-----+-------------------+
    | x1_0| x2_1| x3_0|     phi.value(2)  |
    +-----+-----+-----+-------------------+
    | x1_0| x2_1| x3_1|     phi.value(3)  |
    +-----+-----+-----+-------------------+
    | x1_1| x2_0| x3_0|     phi.value(4)  |
    +-----+-----+-----+-------------------+
    | x1_1| x2_0| x3_1|     phi.value(5)  |
    +-----+-----+-----+-------------------+
    | x1_1| x2_1| x3_0|     phi.value(6)  |
    +-----+-----+-----+-------------------+
    | x1_1| x2_1| x3_1|     phi.value(7)  |
    +-----+-----+-----+-------------------+

    Parameters
    ----------
    variables: list, array-like
        List of variables on which the factor is to be defined i.e. scope of the factor.

    cardinality: list, array_like
        List of cardinalities/no.of states of each variable. `cardinality`
        array must have a value corresponding to each variable in
        `variables`.

    values: list, array_like
        List of values of factor.

        The flat list is reshaped to the shape given by `cardinality`, one
        axis per variable. The ordering is therefore such that the
        right-most variables as defined in `variables` cycle through their
        values the fastest (see the table above).

    state_names: dict, optional
        Forwarded, together with `variables` and `cardinality`, to
        `StateNameMixin.store_state_names`. Must be a dict when given;
        `None` (the default) is treated as an empty dict. Presumably maps
        each variable to the list of its state names.

    Raises
    ------
    TypeError
        If `variables` is a string.
    ValueError
        If `cardinality` and `variables` differ in length, if the number of
        `values` does not equal the product of `cardinality`, if `variables`
        contains duplicates, or if `state_names` is not a dict.

    Examples
    --------
    >>> import numpy as np
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [2, 2, 2], np.ones(8))
    >>> phi
    <DiscreteFactor representing phi(x1:2, x2:2, x3:2) at 0x7f8188fcaa90>
    >>> print(phi)
    +------+------+------+-----------------+
    | x1   | x2   | x3   |   phi(x1,x2,x3) |
    |------+------+------+-----------------|
    | x1_0 | x2_0 | x3_0 |          1.0000 |
    | x1_0 | x2_0 | x3_1 |          1.0000 |
    | x1_0 | x2_1 | x3_0 |          1.0000 |
    | x1_0 | x2_1 | x3_1 |          1.0000 |
    | x1_1 | x2_0 | x3_0 |          1.0000 |
    | x1_1 | x2_0 | x3_1 |          1.0000 |
    | x1_1 | x2_1 | x3_0 |          1.0000 |
    | x1_1 | x2_1 | x3_1 |          1.0000 |
    +------+------+------+-----------------+
    """

    def __init__(self, variables, cardinality, values, state_names=None):
        if isinstance(variables, str):
            raise TypeError("Variables: Expected type list or array like, got string")

        # A `None` sentinel avoids sharing one mutable default dict between
        # every factor created without explicit state names.
        if state_names is None:
            state_names = {}

        # Convert the values to the array/tensor type of the active backend.
        if config.BACKEND == "numpy":
            values = np.array(values, dtype=config.get_dtype())
        else:
            values = (
                torch.Tensor(values).type(config.get_dtype()).to(config.get_device())
            )

        if len(cardinality) != len(variables):
            raise ValueError(
                "Number of elements in cardinality must be equal to number of variables"
            )

        expected_size = np.prod(cardinality)
        if compat_fns.size(values) != expected_size:
            raise ValueError(f"Values array must be of size: {expected_size}")

        if len(set(variables)) != len(variables):
            raise ValueError("Variable names cannot be same")

        if not isinstance(state_names, dict):
            raise ValueError(
                f"state_names must be of type dict. Got {type(state_names)}."
            )

        self.variables = list(variables)
        self.cardinality = np.array(cardinality, dtype=int)
        # One axis per variable, in the order of `variables`.
        self.values = values.reshape(tuple(self.cardinality))

        # Set the state names
        super(DiscreteFactor, self).store_state_names(
            variables, cardinality, state_names
        )
[docs]
def scope(self):
    """
    Returns the scope of the factor i.e. the variables on which the factor is defined.

    The returned list is the factor's own `variables` attribute, not a copy,
    so mutating it mutates the factor.

    Returns
    -------
    Scope of the factor: list
        List of variables on which the factor is defined.

    Examples
    --------
    >>> import numpy as np
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], np.ones(12))
    >>> phi.scope()
    ['x1', 'x2', 'x3']
    """
    return self.variables
[docs]
def get_cardinality(self, variables):
    """
    Look up the number of states of each requested variable.

    Parameters
    ----------
    variables: list, array-like
        Names of the variables to look up. A plain string is rejected.

    Returns
    -------
    Cardinality of variables: dict
        Dictionary of the form {variable: variable_cardinality}, with the
        cardinalities taken from `self.cardinality`.

    Raises
    ------
    TypeError
        If `variables` is a string.
    ValueError
        If any requested variable is not in the factor's scope.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi.get_cardinality(['x1'])
    {'x1': 2}
    >>> phi.get_cardinality(['x1', 'x2'])
    {'x1': 2, 'x2': 3}
    """
    if isinstance(variables, str):
        raise TypeError("variables: Expected type list or array-like, got type str")

    # Validate everything first so nothing is returned for a partly valid request.
    for name in variables:
        if name not in self.variables:
            raise ValueError("Variable not in scope")

    result = {}
    for name in variables:
        position = self.variables.index(name)
        result[name] = self.cardinality[position]
    return result
[docs]
def get_value(self, **kwargs):
    """
    Return the factor value for one full assignment of its variables.

    Each state is first looked up as a state name in `self.name_to_no`; if
    that lookup fails the given state is used directly as a state number.

    Parameters
    ----------
    kwargs: named arguments of the form variable=state_name
        Specifies the state of each of the variables for which to get
        the value. Every variable of the factor must be given.

    Returns
    -------
    value of kwargs: float
        The value of specified states.

    Raises
    ------
    ValueError
        If an argument names a variable that is not in the factor, or if a
        variable of the factor is missing from the arguments.

    Examples
    --------
    >>> from pgmpy.utils import get_example_model
    >>> model = get_example_model("asia")
    >>> phi = model.get_cpds("either").to_factor()
    >>> phi.get_value(lung="yes", tub="no", either="yes")
    1.0
    """
    for name in kwargs:
        if name not in self.variables:
            raise ValueError(f"Factor doesn't have the variable: {name}")

    position = []
    for var in self.variables:
        if var not in kwargs:
            raise ValueError(f"Variable: {var} not found in arguments")

        requested = kwargs[var]
        try:
            state_no = self.name_to_no[var][requested]
        except KeyError:
            # Not a known state name: treat the argument as a state number.
            logger.info(f"Using {var} state as number instead of name.")
            state_no = requested
        position.append(state_no)

    return self.values[tuple(position)]
[docs]
def set_value(self, value, **kwargs):
    """
    Sets the probability value of the given variable states.

    States are resolved the same way as in `get_value`: each state is first
    looked up as a state name in `self.name_to_no` (whatever its type), and
    only if it is not a known name is it used directly as a state number.
    An unknown *string* state still raises `KeyError`.

    Parameters
    ----------
    value: float
        The value for the specified state.

    kwargs: named arguments of the form variable=state_name
        Specifies the state of each of the variables for which to set
        the probability value. Every variable of the factor must be given.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `value` is not an int or float, if an argument names a variable
        that is not in the factor, or if a variable of the factor is missing
        from the arguments.
    KeyError
        If a string state is not a known state name of its variable.

    Examples
    --------
    >>> from pgmpy.utils import get_example_model
    >>> model = get_example_model("asia")
    >>> phi = model.get_cpds("either").to_factor()
    >>> phi.set_value(value=0.1, lung="yes", tub="no", either="yes")
    >>> phi.get_value(lung='yes', tub='no', either='yes')
    0.1
    """
    if not isinstance(value, (float, int)):
        raise ValueError(f"value must be float. Got: {type(value)}.")

    for variable in kwargs:
        if variable not in self.variables:
            raise ValueError(f"Factor doesn't have the variable: {variable}")

    index = []
    for var in self.variables:
        if var not in kwargs:
            raise ValueError(f"Variable: {var} not found in arguments")

        state = kwargs[var]
        try:
            # Try the state-name lookup first so non-string state names
            # address the same cell that `get_value` reads.
            index.append(self.name_to_no[var][state])
        except (KeyError, TypeError):
            # TypeError covers unhashable states (e.g. index objects), which
            # can only be meant as positions.
            if isinstance(state, str):
                # A string can only be a state name; keep failing loudly.
                raise
            logger.info(f"Using {var} state as number instead of name.")
            index.append(state)

    self.values[tuple(index)] = value
[docs]
def assignment(self, index):
    """
    Translate flat value indices into full variable assignments.

    Parameters
    ----------
    index: list, array-like
        List of indices whose assignment is to be computed

    Returns
    -------
    Full assignments: list
        One entry per given index; each entry is a list of
        (variable, state) tuples covering all variables of the factor, with
        states obtained through `self.get_state_names`.

    Raises
    ------
    IndexError
        If any index is larger than the largest possible flat index.

    Examples
    --------
    >>> import numpy as np
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['diff', 'intel'], [2, 2], np.ones(4))
    >>> phi.assignment([1, 2])
    [[('diff', 0), ('intel', 1)], [('diff', 1), ('intel', 0)]]
    """
    # Bring the indices into the array type of the active backend.
    if config.get_backend() == "numpy":
        index = np.array(index)
    elif (len(index) == 1) and (isinstance(index[0], torch.Tensor)):
        index = index[0][None]
    else:
        index = torch.tensor(index, dtype=torch.int, device=config.get_device())

    upper_bound = np.prod(self.cardinality) - 1
    within_range = all(i <= upper_bound for i in index)
    if not within_range:
        raise IndexError("Index greater than max possible index")

    backend = compat_fns.get_compute_backend()
    state_table = backend.zeros((len(index), len(self.scope())), dtype=int)

    # Peel off one variable at a time, starting with the last (fastest
    # cycling) one, so the columns are filled in reversed variable order.
    remaining = index
    for column, card in enumerate(self.cardinality[::-1]):
        state_table[:, column] = remaining % card
        remaining = remaining // card

    # Restore the original variable order of the columns.
    state_table = compat_fns.flip(state_table, axis=(1,))

    result = []
    for row in state_table:
        pairs = []
        for var, state_no in zip(self.variables, row):
            pairs.append((var, self.get_state_names(var, int(state_no))))
        result.append(pairs)
    return result
[docs]
def identity_factor(self):
    """
    Build the identity factor for this factor.

    The identity factor has the same variables, cardinality and state names
    as this factor, but every value is 1, so multiplying it with this
    factor gives this factor back.

    Returns
    -------
    Identity factor: pgmpy.factors.discrete.DiscreteFactor.
        Returns a factor with all values set to 1.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi_identity = phi.identity_factor()
    >>> phi_identity.variables
    ['x1', 'x2', 'x3']
    >>> phi_identity.values
    array([[[ 1.,  1.],
            [ 1.,  1.],
            [ 1.,  1.]],
           [[ 1.,  1.],
            [ 1.,  1.],
            [ 1.,  1.]]])
    """
    n_entries = compat_fns.size(self.values)
    all_ones = compat_fns.ones(n_entries)
    return DiscreteFactor(
        variables=self.variables,
        cardinality=self.cardinality,
        values=all_ones,
        state_names=self.state_names,
    )
[docs]
def marginalize(self, variables, inplace=True):
    """
    Sum the given variables out of the factor.

    Parameters
    ----------
    variables: list, array-like
        List of variables over which to marginalize.

    inplace: boolean
        If inplace=True it will modify the factor itself, else would return
        a new factor.

    Returns
    -------
    Marginalized factor: pgmpy.factors.discrete.DiscreteFactor or None
        If inplace=True (default) returns None else returns a new `DiscreteFactor` instance.

    Raises
    ------
    TypeError
        If `variables` is a string.
    ValueError
        If a variable is not in the factor's scope.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi.marginalize(['x1', 'x3'])
    >>> phi.values
    array([14., 22., 30.])
    >>> phi.variables
    ['x2']
    """
    if isinstance(variables, str):
        raise TypeError("variables: Expected type list or array-like, got type str")

    target = self if inplace else self.copy()

    for var in variables:
        if var not in target.variables:
            raise ValueError(f"{var} not in scope.")

    # Axis bookkeeping has to happen before `target.variables` is rewritten.
    n_axes = len(self.variables)
    summed_axes = {target.variables.index(var) for var in variables}
    kept_axes = [axis for axis in range(n_axes) if axis not in summed_axes]

    target.variables = [target.variables[axis] for axis in kept_axes]
    target.cardinality = target.cardinality[kept_axes]
    target.del_state_names(variables)

    # Axes missing from the output subscripts are summed over.
    target.values = compat_fns.einsum(target.values, range(n_axes), kept_axes)

    if inplace:
        return None
    return target
[docs]
def maximize(self, variables, inplace=True):
    """
    Remove the given variables by taking the maximum over their states.

    Parameters
    ----------
    variables: list, array-like
        List of variables with respect to which factor is to be maximized

    inplace: boolean
        If inplace=True it will modify the factor itself, else would return
        a new factor.

    Returns
    -------
    Maximized factor: pgmpy.factors.discrete.DiscreteFactor or None
        If inplace=True (default) returns None else inplace=False returns a
        new `DiscreteFactor` instance.

    Raises
    ------
    TypeError
        If `variables` is a string.
    ValueError
        If a variable is not in the factor's scope.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [3, 2, 2], [0.25, 0.35, 0.08, 0.16, 0.05, 0.07,
    ...                                              0.00, 0.00, 0.15, 0.21, 0.09, 0.18])
    >>> phi.variables
    ['x1', 'x2', 'x3']
    >>> phi.maximize(['x2'])
    >>> phi.variables
    ['x1', 'x3']
    >>> phi.cardinality
    array([3, 2])
    >>> phi.values
    array([[ 0.25,  0.35],
           [ 0.05,  0.07],
           [ 0.15,  0.21]])
    """
    if isinstance(variables, str):
        raise TypeError("variables: Expected type list or array-like, got type str")

    target = self if inplace else self.copy()

    for var in variables:
        if var not in target.variables:
            raise ValueError(f"{var} not in scope.")

    # Axes to reduce over, in the order the variables were given.
    dropped_axes = [target.variables.index(var) for var in variables]
    dropped_lookup = set(dropped_axes)
    kept_axes = [
        axis for axis in range(len(self.variables)) if axis not in dropped_lookup
    ]

    target.variables = [target.variables[axis] for axis in kept_axes]
    target.cardinality = target.cardinality[kept_axes]
    target.del_state_names(variables)
    target.values = compat_fns.max(target.values, axis=tuple(dropped_axes))

    if inplace:
        return None
    return target
[docs]
def normalize(self, inplace=True):
    """
    Divide every value by the sum of all values, so that they sum to 1.

    Parameters
    ----------
    inplace: boolean
        If inplace=True it will modify the factor itself, else would return
        a new factor

    Returns
    -------
    Normalized factor: pgmpy.factors.discrete.DiscreteFactor or None
        If inplace=True (default) returns None else returns a new `DiscreteFactor` instance.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi.values
    array([[[ 0.,  1.],
            [ 2.,  3.],
            [ 4.,  5.]],
           [[ 6.,  7.],
            [ 8.,  9.],
            [10., 11.]]])
    >>> phi.normalize()
    >>> phi.variables
    ['x1', 'x2', 'x3']
    >>> phi.cardinality
    array([2, 3, 2])
    >>> phi.values
    array([[[ 0.        ,  0.01515152],
            [ 0.03030303,  0.04545455],
            [ 0.06060606,  0.07575758]],
           [[ 0.09090909,  0.10606061],
            [ 0.12121212,  0.13636364],
            [ 0.15151515,  0.16666667]]])
    """
    if inplace:
        target = self
    else:
        target = self.copy()

    total = target.values.sum()
    target.values = target.values / total

    return None if inplace else target
[docs]
def reduce(self, values, inplace=True, show_warnings=True):
    """
    Reduces the factor to the context of given variable values. The variables which
    are reduced would be removed from the factor.

    Parameters
    ----------
    values: list, array-like
        A list of tuples of the form (variable_name, variable_state).

    inplace: boolean
        If inplace=True it will modify the factor itself, else would return
        a new factor.

    show_warnings: boolean
        Whether to show warning when state name not found.

    Returns
    -------
    Reduced factor: pgmpy.factors.discrete.DiscreteFactor or None
        If inplace=True (default) returns None else returns a new `DiscreteFactor` instance.

    Raises
    ------
    TypeError
        If `values` is a string or contains an element that is not a tuple.
    ValueError
        If a variable in `values` is not in the factor.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi.reduce([('x1', 0), ('x2', 0)])
    >>> phi.variables
    ['x3']
    >>> phi.cardinality
    array([2])
    >>> phi.values
    array([0., 1.])
    """
    # Check if values is an array
    if isinstance(values, str):
        raise TypeError("values: Expected type list or array-like, got type str")

    # Report the type of the element that is actually wrong, in a single
    # formatted message.
    for state_tuple in values:
        if not isinstance(state_tuple, tuple):
            raise TypeError(
                "values: Expected type list of tuples, got type "
                f"{type(state_tuple)}"
            )

    # Check if all variables in values are in the factor
    for var, _ in values:
        if var not in self.variables:
            raise ValueError(f"The variable: {var} is not in the factor")

    phi = self if inplace else self.copy()

    # Convert the state names to state number. If state name not found treat them as
    # state numbers.
    try:
        values = [
            (var, self.get_state_no(var, state_name)) for var, state_name in values
        ]
    except KeyError:
        # `values` is left untouched, i.e. all given states are used as numbers.
        if show_warnings:
            logger.warning(
                "Found unknown state name. Trying to switch to using all state names as state numbers"
            )

    # Fix the reduced axes to the requested state; keep all others whole.
    var_index_to_del = []
    slice_ = [slice(None)] * len(self.variables)
    for var, state in values:
        var_index = phi.variables.index(var)
        slice_[var_index] = state
        var_index_to_del.append(var_index)

    # set difference is not guaranteed to maintain ordering
    var_index_to_keep = sorted(
        set(range(len(phi.variables))) - set(var_index_to_del)
    )
    phi.variables = [phi.variables[index] for index in var_index_to_keep]
    phi.cardinality = phi.cardinality[var_index_to_keep]
    phi.del_state_names([var for var, _ in values])

    phi.values = phi.values[tuple(slice_)]

    if not inplace:
        return phi
[docs]
def sum(self, phi1, inplace=True):
    """
    DiscreteFactor sum with `phi1`.

    When `phi1` is a factor, the result is defined over the union of both
    scopes: this factor's variables first (in their current order),
    followed by the variables that only `phi1` has, in `phi1`'s order.

    Parameters
    ----------
    phi1: float or `DiscreteFactor` instance.
        If float, the value is added to each value in the factor.
        If `DiscreteFactor`, the factor to be added.

    inplace: boolean
        If inplace=True it will modify the factor itself, else would return
        a new factor.

    Returns
    -------
    Summed factor: pgmpy.factors.discrete.DiscreteFactor or None
        If inplace=True (default) returns None else returns a new `DiscreteFactor` instance.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi1 = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi2 = DiscreteFactor(['x3', 'x4', 'x1'], [2, 2, 2], range(8))
    >>> phi1.sum(phi2, inplace=True)
    >>> phi1.variables
    ['x1', 'x2', 'x3', 'x4']
    >>> phi1.cardinality
    array([2, 3, 2, 2])
    >>> phi1.values
    array([[[[ 0.,  2.],
             [ 5.,  7.]],
            [[ 2.,  4.],
             [ 7.,  9.]],
            [[ 4.,  6.],
             [ 9., 11.]]],
           [[[ 7.,  9.],
             [12., 14.]],
            [[ 9., 11.],
             [14., 16.]],
            [[11., 13.],
             [16., 18.]]]])
    """
    phi = self if inplace else self.copy()
    if isinstance(phi1, (int, float)):
        phi.values += phi1
    else:
        # Work on a copy: phi1's axes get extended and reordered below.
        phi1 = phi1.copy()

        # modifying phi to add new variables
        # A list (not a set) keeps the order of the appended variables
        # deterministic across runs.
        extra_vars = [var for var in phi1.variables if var not in phi.variables]
        if extra_vars:
            # Append one length-1 axis per new variable; broadcasting in the
            # final addition expands them to full size.
            slice_ = [slice(None)] * len(phi.variables)
            slice_.extend([np.newaxis] * len(extra_vars))
            phi.values = phi.values[tuple(slice_)]

            phi.variables.extend(extra_vars)

            new_var_card = phi1.get_cardinality(extra_vars)
            phi.cardinality = np.append(
                phi.cardinality, [new_var_card[var] for var in extra_vars]
            )
            phi.add_state_names(phi1)

        # modifying phi1 to add new variables
        extra_vars = [var for var in phi.variables if var not in phi1.variables]
        if extra_vars:
            slice_ = [slice(None)] * len(phi1.variables)
            slice_.extend([np.newaxis] * len(extra_vars))
            phi1.values = phi1.values[tuple(slice_)]

            phi1.variables.extend(extra_vars)
            # No need to modify cardinality as we don't need it.

        # rearranging the axes of phi1 to match phi
        for axis in range(phi.values.ndim):
            exchange_index = phi1.variables.index(phi.variables[axis])
            phi1.variables[axis], phi1.variables[exchange_index] = (
                phi1.variables[exchange_index],
                phi1.variables[axis],
            )
            phi1.values = phi1.values.swapaxes(axis, exchange_index)

        phi.values = phi.values + phi1.values

    if not inplace:
        return phi
[docs]
def product(self, phi1, inplace=True):
    """
    DiscreteFactor product with `phi1`.

    When `phi1` is a factor, the result is defined over the union of both
    scopes: this factor's variables first (in their current order),
    followed by the variables that only `phi1` has, in `phi1`'s order.

    Parameters
    ----------
    phi1: float or `DiscreteFactor` instance
        If float, all the values are multiplied with `phi1`.
        else if `DiscreteFactor` instance, multiply based on matching rows.

    inplace: boolean
        If inplace=True it will modify the factor itself, else would return
        a new factor.

    Returns
    -------
    Multiplied factor: pgmpy.factors.discrete.DiscreteFactor or None
        If inplace=True (default) returns None else returns a new `DiscreteFactor` instance.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi1 = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi2 = DiscreteFactor(['x3', 'x4', 'x1'], [2, 2, 2], range(8))
    >>> phi1.product(phi2, inplace=True)
    >>> phi1.variables
    ['x1', 'x2', 'x3', 'x4']
    >>> phi1.cardinality
    array([2, 3, 2, 2])
    >>> phi1.values
    array([[[[ 0,  0],
             [ 4,  6]],
            [[ 0,  4],
             [12, 18]],
            [[ 0,  8],
             [20, 30]]],
           [[[ 6, 18],
             [35, 49]],
            [[ 8, 24],
             [45, 63]],
            [[10, 30],
             [55, 77]]]])
    """
    phi = self if inplace else self.copy()
    if isinstance(phi1, (int, float)):
        phi.values *= phi1
    else:
        # Compute the new values
        # Build the union of both scopes in a deterministic order instead
        # of relying on set iteration order.
        known = set(phi.variables)
        new_variables = list(phi.variables)
        new_variables.extend(var for var in phi1.variables if var not in known)

        # Each variable gets an integer label; shared labels line up the
        # matching axes of both operands in the einsum call.
        var_to_int = {var: index for index, var in enumerate(new_variables)}
        phi.values = compat_fns.einsum(
            phi.values,
            [var_to_int[var] for var in phi.variables],
            phi1.values,
            [var_to_int[var] for var in phi1.variables],
            range(len(new_variables)),
        )

        # Compute the new cardinality array
        phi_card = {var: card for var, card in zip(phi.variables, phi.cardinality)}
        phi1_card = {
            var: card for var, card in zip(phi1.variables, phi1.cardinality)
        }
        phi_card.update(phi1_card)
        phi.cardinality = np.array([phi_card[var] for var in new_variables])

        # Set the new variables and state names
        phi.variables = new_variables
        phi.add_state_names(phi1)

    if not inplace:
        return phi
[docs]
def divide(self, phi1, inplace=True):
    """
    Divide this factor element-wise by `phi1`.

    The scope of `phi1` must be a subset of this factor's scope. Entries
    that come out as NaN after the division (e.g. 0/0) are set to 0.

    Parameters
    ----------
    phi1 : `DiscreteFactor` instance
        The denominator for division.

    inplace: boolean
        If inplace=True it will modify the factor itself, else would return
        a new factor.

    Returns
    -------
    Divided factor: pgmpy.factors.discrete.DiscreteFactor or None
        If inplace=True (default) returns None else returns a new `DiscreteFactor` instance.

    Raises
    ------
    ValueError
        If `phi1` has a variable that this factor does not have.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi1 = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi2 = DiscreteFactor(['x3', 'x1'], [2, 2], range(1, 5))
    >>> phi1.divide(phi2)
    >>> phi1.variables
    ['x1', 'x2', 'x3']
    >>> phi1.cardinality
    array([2, 3, 2])
    >>> phi1.values
    array([[[ 0.        ,  0.33333333],
            [ 2.        ,  1.        ],
            [ 4.        ,  1.66666667]],
           [[ 3.        ,  1.75      ],
            [ 4.        ,  2.25      ],
            [ 5.        ,  2.75      ]]])
    """
    numerator = self if inplace else self.copy()
    # The divisor is copied because its axes are extended and reordered below.
    denominator = phi1.copy()

    if set(denominator.variables) - set(numerator.variables):
        raise ValueError("Scope of divisor should be a subset of dividend")

    # Give the divisor a length-1 axis for every variable it lacks, so that
    # broadcasting can line it up with the dividend.
    missing = set(numerator.variables) - set(denominator.variables)
    if missing:
        expand = [slice(None)] * len(denominator.variables)
        expand.extend([np.newaxis] * len(missing))
        denominator.values = denominator.values[tuple(expand)]
        denominator.variables.extend(missing)

    # Reorder the divisor's axes to the dividend's variable order.
    names = denominator.variables
    for axis in range(numerator.values.ndim):
        source = names.index(numerator.variables[axis])
        names[axis], names[source] = names[source], names[axis]
        denominator.values = denominator.values.swapaxes(axis, source)

    numerator.values = numerator.values / denominator.values

    # If factor division 0/0 = 0 but is undefined for x/0. In pgmpy we are using
    # np.inf to represent x/0 cases.
    nan_mask = config.get_compute_backend().isnan(numerator.values)
    numerator.values[nan_mask] = 0

    if inplace:
        return None
    return numerator
[docs]
def sample(self, n):
    """
    Draw `n` state combinations from the normalized factor.

    The factor itself is left unchanged; a normalized copy supplies the
    sampling probabilities.

    Parameters
    ----------
    n: int
        No. of samples to return

    Returns
    -------
    pandas.DataFrame
        One row per sample and one column per variable of the factor.

    Examples
    --------
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi1 = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 2], range(12))
    >>> phi1.sample(5)
       x1  x2  x3
    0   1   0   0
    1   0   2   0
    2   1   2   0
    3   1   1   1
    4   1   1   1
    """
    normalized = self.normalize(inplace=False)
    # TODO: Fix this to make it work natively in torch.
    probabilities = compat_fns.to_numpy(normalized.values.ravel())
    drawn = np.random.choice(range(len(probabilities)), size=n, p=probabilities)

    # Resolve each distinct flat index to its assignment only once.
    resolved = {}
    rows = []
    for flat_index in drawn:
        if flat_index not in resolved:
            resolved[flat_index] = self.assignment([flat_index])[0]
        rows.append(resolved[flat_index])

    return pd.DataFrame([dict(row) for row in rows])
[docs]
def copy(self):
    """
    Create a copy of the factor without running `__init__`.

    The variables list, cardinality array and values are duplicated; the
    state-name attributes are duplicated through their own `.copy()`.

    Returns
    -------
    Copy of self: pgmpy.factors.discrete.DiscreteFactor
        A copy of the original discrete factor, of the same class as `self`.

    Examples
    --------
    >>> import numpy as np
    >>> from pgmpy.factors.discrete import DiscreteFactor
    >>> phi = DiscreteFactor(['x1', 'x2', 'x3'], [2, 3, 3], np.arange(18))
    >>> phi_copy = phi.copy()
    >>> phi_copy.variables
    ['x1', 'x2', 'x3']
    >>> phi_copy.cardinality
    array([2, 3, 3])
    >>> phi_copy.values
    array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8]],
           [[ 9, 10, 11],
            [12, 13, 14],
            [15, 16, 17]]])
    """
    # Bypass __init__ so the already validated state is not re-processed.
    duplicate = DiscreteFactor.__new__(self.__class__)
    duplicate.variables = list(self.variables)
    duplicate.cardinality = np.array(self.cardinality)
    duplicate.values = compat_fns.copy(self.values)
    for attr in ("state_names", "no_to_name", "name_to_no"):
        setattr(duplicate, attr, getattr(self, attr).copy())
    return duplicate
[docs]
def is_valid_cpd(self):
    """
    Checks if the factor's values can be used for a valid CPD.

    The factor obtained from `self.to_factor()` is summed over the first
    variable in scope, and the check passes when every remaining entry is
    close to 1 (backend `allclose` with `atol=0.01`).

    NOTE(review): `to_factor` is not defined in the part of the class shown
    here; presumably it comes from a subclass or base class -- confirm
    before relying on this method for a plain `DiscreteFactor`.

    Returns
    -------
    The result of the backend's `allclose` comparison.
    """
    return config.get_compute_backend().allclose(
        self.to_factor()
        .marginalize(self.scope()[:1], inplace=False)
        .values.flatten(),
        # One expected value of 1 per joint state of all variables but the first.
        compat_fns.ones(np.prod(self.cardinality[:0:-1])),
        atol=0.01,
    )
def __str__(self):
    """Return the factor as a grid-formatted table, labelled with `phi`."""
    return self._str(phi_or_p="phi", tablefmt="grid")
def _str(self, phi_or_p="phi", tablefmt="grid", print_state_names=True):
    """
    Build the table string used by `__str__`.

    The table has one column per variable plus a value column, and one row
    per joint state of the variables.

    Parameters
    ----------
    phi_or_p: 'phi' | 'p'
        'phi': When used for Factors.
        'p': When used for CPDs.

    tablefmt: str
        Table format name handed to `tabulate`.

    print_state_names: boolean
        If True, the user defined state names are displayed.

    Returns
    -------
    str
        The formatted table, values shown with four decimals.
    """
    var_names = list(self.variables)

    headers = [str(var) for var in self.scope()]
    headers.append(f"{phi_or_p}({','.join(headers)})")

    show_names = self.state_names and print_state_names
    state_ranges = [range(card) for card in self.cardinality]

    rows = []
    # `product` enumerates joint states with the last variable changing
    # fastest, matching the order of the flattened values.
    for row_no, states in enumerate(product(*state_ranges)):
        if show_names:
            row = [
                "{var}({state})".format(
                    var=var, state=self.state_names[var][states[i]]
                )
                for i, var in enumerate(var_names)
            ]
        else:
            row = [f"{var}_{states[i]}" for i, var in enumerate(var_names)]

        row.append(self.values.ravel()[row_no])
        rows.append(row)

    return tabulate(rows, headers=headers, tablefmt=tablefmt, floatfmt=".4f")
def __repr__(self):
var_card = ", ".join(
[f"{var}:{card}" for var, card in zip(self.variables, self.cardinality)]
)
return f"<DiscreteFactor representing phi({var_card}) at {hex(id(self))}>"
def __mul__(self, other):
    """Return `self * other` as a new factor; `self` is not modified."""
    return self.product(other, inplace=False)
def __rmul__(self, other):
    """Return `other * self`; delegates to `__mul__` with the same operand."""
    return self.__mul__(other)
def __add__(self, other):
    """Return `self + other` as a new factor; `self` is not modified."""
    return self.sum(other, inplace=False)
def __radd__(self, other):
    """Return `other + self`; delegates to `__add__` with the same operand."""
    return self.__add__(other)
def __truediv__(self, other):
    """Return `self / other` as a new factor; `self` is not modified."""
    return self.divide(other, inplace=False)

# The same implementation exposed under the legacy `__div__` name.
__div__ = __truediv__
def __eq__(self, other, atol=1e-08):
    """
    Method for checking if two factors are equal.

    Two factors are equal when they have the same set of variables, the
    same set of state names per variable, the same cardinalities and values
    that are close to each other. The order of variables and the order of
    state names may differ; `other` is realigned to `self` on a copy before
    the values are compared.

    Parameters
    ----------
    other: object
        The object to compare with. Anything that is not a
        `DiscreteFactor` compares unequal.

    atol: float
        The maximum allowed difference in values to be considered equal.
        The `==` operator always uses the default; a different tolerance
        requires calling `__eq__` directly.

    Returns
    -------
    bool
    """
    if not (isinstance(self, DiscreteFactor) and isinstance(other, DiscreteFactor)):
        return False
    elif set(self.scope()) != set(other.scope()):
        return False
    else:
        # Change the axis so that the variables are in the same order.
        # All realignment is done on a copy, `other` stays untouched.
        phi = other.copy()
        if self.variables != phi.variables:
            for axis in range(self.values.ndim):
                # Swap the variable that belongs at `axis` into place and
                # keep variables, cardinality and values in sync.
                exchange_index = phi.variables.index(self.variables[axis])
                phi.variables[axis], phi.variables[exchange_index] = (
                    phi.variables[exchange_index],
                    phi.variables[axis],
                )
                phi.cardinality[axis], phi.cardinality[exchange_index] = (
                    phi.cardinality[exchange_index],
                    phi.cardinality[axis],
                )
                phi.values = phi.values.swapaxes(axis, exchange_index)

        # Check the state names order and match them
        for axis, var in enumerate(self.variables):
            if set(self.state_names[var]) != set(phi.state_names[var]):
                return False
            elif self.state_names[var] != phi.state_names[var]:
                # Same names in a different order: permute phi's values
                # along this axis into self's state order.
                ref_index = []
                for state_name in self.state_names[var]:
                    ref_index.append(phi.state_names[var].index(state_name))
                slice_ = [slice(None)] * len(self.variables)
                slice_[axis] = ref_index
                phi.values = phi.values[tuple(slice_)]

        if phi.values.shape != self.values.shape:
            return False
        elif not config.get_compute_backend().allclose(
            phi.values, self.values, atol=atol
        ):
            return False
        elif not all(self.cardinality == phi.cardinality):
            return False
        else:
            return True
def __ne__(self, other):
    """Return the negation of `__eq__`, evaluated with its default tolerance."""
    return not self.__eq__(other)
def __hash__(self):
    """
    Hash the factor independently of the order of its variables.

    The axes of a copy are rearranged into the order of the sorted variable
    hashes; the hash then combines the sorted variable hashes, the bytes of
    the rearranged values and cardinality, and a hash of `state_names`.

    NOTE(review): the values are hashed exactly, whereas `__eq__` compares
    them with a tolerance and also accepts a different state-name order, so
    two factors that compare equal may still hash differently.
    """
    variable_hashes = [hash(variable) for variable in self.variables]
    sorted_var_hashes = sorted(variable_hashes)
    # Iterating `state_names` yields its keys if it is a dict, so presumably
    # only the variable names (not the state names) contribute here.
    state_names_hash = hash(frozenset(self.state_names))
    # Rearrange a copy so that hashing never modifies the factor itself.
    phi = self.copy()
    for axis in range(phi.values.ndim):
        # Move the variable with the axis-th smallest hash to position `axis`.
        exchange_index = variable_hashes.index(sorted_var_hashes[axis])
        variable_hashes[axis], variable_hashes[exchange_index] = (
            variable_hashes[exchange_index],
            variable_hashes[axis],
        )
        phi.cardinality[axis], phi.cardinality[exchange_index] = (
            phi.cardinality[exchange_index],
            phi.cardinality[axis],
        )
        phi.values = phi.values.swapaxes(axis, exchange_index)
    return hash(
        str(sorted_var_hashes)
        + str(hash(compat_fns.tobytes(phi.values)))
        + str(hash(compat_fns.tobytes(phi.cardinality)))
        + str(state_names_hash)
    )