pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +172 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
pyagrum/explain/_ShapleyValues.py
@@ -0,0 +1,298 @@
+ ############################################################################
+ # This file is part of the aGrUM/pyAgrum library. #
+ # #
+ # Copyright (c) 2005-2025 by #
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
+ # - Christophe GONZALES(_at_AMU) #
+ # #
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
+ # and/or modify it under the terms of either : #
+ # #
+ # - the GNU Lesser General Public License as published by #
+ # the Free Software Foundation, either version 3 of the License, #
+ # or (at your option) any later version, #
+ # - the MIT license (MIT), #
+ # - or both in dual license, as here. #
+ # #
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
+ # #
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
+ # OTHER DEALINGS IN THE SOFTWARE. #
+ # #
+ # See LICENCES for more details. #
+ # #
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
+ # - Christophe GONZALES(_at_AMU) #
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
+ # #
+ # Contact : info_at_agrum_dot_org #
+ # homepage : http://agrum.gitlab.io #
+ # gitlab : https://gitlab.com/agrumery/agrum #
+ # #
+ ############################################################################
+
+ from pyagrum.explain._Explainer import Explainer
+ from abc import abstractmethod
+ from pyagrum.explain._Explanation import Explanation
+
+ # Calculations
+ import pandas as pd
+ import numpy as np
+
+ # aGrUM
+ import pyagrum as gum
+
+ # GL
+ import warnings
+
+
+ class ShapleyValues(Explainer):
+   """
+   The ShapleyValues class is an abstract base class for computing Shapley values in a Bayesian Network.
+   """
+
+   def __init__(self, bn, target, logit):
+     """
+     Parameters:
+     ------
+     bn : pyagrum.BayesNet
+       The Bayesian Network.
+     target : int | str
+       The node id (or node name) of the target.
+     logit : bool
+       If True, applies the logit transformation to the probabilities.
+
+     Raises:
+     ------
+     TypeError : If bn is not a gum.BayesNet or target is not an integer or string.
+     ValueError : If target is not a valid node id in the Bayesian Network.
+     """
+     super().__init__(bn)
+     if isinstance(target, str):
+       if target not in bn.names():
+         raise ValueError("Target node name '{}' not found in the Bayesian Network.".format(target))
+       target = bn.idFromName(target)  # Convert node name to ID.
+     elif isinstance(target, int):
+       if target not in bn.nodes():
+         raise ValueError("Target node ID {} not found in the Bayesian Network.".format(target))
+     else:
+       raise TypeError("Target must be a node ID (int) or a node name (str), but got {}".format(type(target)))
+     if not isinstance(logit, bool):
+       warnings.warn("logit should be a boolean, unexpected calculation may occur.", UserWarning)
+
+     # Class attributes.
+     self.target = target  # ID of the target node.
+     self.target_name = self.feat_names[self.target]
+     self._mb = self._markov_blanket()
+     self.ie = gum.LazyPropagation(self.bn)  # Inference engine for the Bayesian Network.
+     self.ie.addTarget(self.target)  # Setting the target for inference.
+     self.func = self._logit if logit else self._identity  # Function to apply to the probabilities.
+
+   def _markov_blanket(self):
+     # Retrieves the Markov blanket of the target node.
+     mb = gum.MarkovBlanket(self.bn, self.target).nodes()
+     mb.remove(self.target)
+     return sorted(list(mb))
+
+   def _posterior(self, evidces: dict[int, int]):
+     # Returns the posterior probability of the target given the evidence.
+     self.ie.updateEvidence(evidces)
+     return self.ie.posterior(self.target).toarray()
+
+   @abstractmethod
+   def _shap_1dim(self, x, elements):
+     # Computes the Shapley values for a single instance.
+     # This method should be implemented in subclasses.
+     raise NotImplementedError("This method should be implemented in subclasses.")
+
+   @abstractmethod
+   def _shap_ndim(self, x, elements):
+     # Computes the Shapley values for multiple instances.
+     # This method should be implemented in subclasses.
+     raise NotImplementedError("This method should be implemented in subclasses.")
+
+   def compute(self, data: tuple | None, N=100):
+     """
+     Computes the Shapley values for the target node based on the provided data.
+
+     Parameters:
+     ----------
+     data : tuple | None
+       A tuple (df, with_labels) where df is a pandas DataFrame, Series or dict, and with_labels is a boolean indicating whether the data are given as labels. If None, a random sample of size N is generated.
+     N : int
+       The number of samples to generate if data is None.
+
+     Returns:
+     -------
+     Explanation
+       An Explanation object containing the Shapley values and importances for the target node.
+
+     Raises:
+     ------
+     TypeError : If the first element of data is not a pd.DataFrame, pd.Series or dict, or if N is not an integer when data is None.
+     ValueError : If N is less than 2 when data is None.
+     """
+     if data is None:
+       if not isinstance(N, int):
+         raise TypeError("Since data is None, N must be an integer, but got {}".format(type(N)))
+       if N < 2:
+         raise ValueError("N must be greater than 1, but got {}".format(N))
+       y = gum.generateSample(self.bn, N, with_labels=False)[0].reindex(columns=self.feat_names).to_numpy()
+       elements = [i for i in range(self.M) if i != self.target]
+       # Remove duplicate rows in x and unused columns.
+       mask_cols = [i for i in range(self.M) if i not in elements]
+       _, idx = np.unique(y[:, elements], axis=0, return_index=True)
+       y = y[idx, :]
+       y[:, mask_cols] = 0
+       contributions = self._shap_ndim(y, sorted(elements))
+
+     else:
+       if not isinstance(data, tuple):
+         raise TypeError("`data` must be a tuple (pd.DataFrame, bool).")
+       df, with_labels = data
+       if not isinstance(with_labels, bool):
+         warnings.warn(
+           f"The second element of `data` should be a boolean, but got {type(with_labels)}. Unexpected calculations may occur."
+         )
+       dtype = "U50" if with_labels else int
+
+       if isinstance(df, pd.Series):
+         # Here we are sure that df is a single instance (a Series).
+         s = df.dropna()
+         x = np.empty(self.M, dtype=dtype)
+         elements = []
+         for feat in s.index:
+           id = self.bn.idFromName(feat)
+           x[id] = s[feat]
+           if id != self.target:
+             elements.append(id)
+         if with_labels:
+           y = self._labelToPos_row(x, elements)
+         else:
+           y = x
+         contributions = self._shap_1dim(y, sorted(elements))
+
+       elif isinstance(df, pd.DataFrame):
+         df_clean = df.dropna(axis=1)
+         if len(df_clean) == 1:
+           # Here we are sure that df is a single instance (a DataFrame with one row).
+           x = np.empty(self.M, dtype=dtype)
+           elements = []
+           for feat in df_clean.columns:
+             id = self.bn.idFromName(feat)
+             x[id] = df_clean[feat].values[0]
+             if id != self.target:
+               elements.append(id)
+           if with_labels:
+             y = self._labelToPos_row(x, elements)
+           else:
+             y = x
+           contributions = self._shap_1dim(y, sorted(elements))
+
+         else:
+           x = np.empty((len(df_clean), self.M), dtype=dtype)
+           elements = []
+           for feat in df_clean.columns:
+             id = self.bn.idFromName(feat)
+             x[:, id] = df_clean[feat].values
+             if id != self.target:
+               elements.append(id)
+           # Remove duplicate rows in x and unused columns.
+           mask_cols = [i for i in range(self.M) if i not in elements]
+           _, idx = np.unique(x[:, elements], axis=0, return_index=True)
+           x = x[idx, :]
+           x[:, mask_cols] = 0
+           if with_labels:
+             y = self._labelToPos_df(x, elements)
+           else:
+             y = x
+           contributions = self._shap_ndim(y, sorted(elements))
+
+       elif isinstance(df, dict):
+         try:
+           N = len(list(df.values())[0])
+           if not isinstance(list(df.values())[0], (list, np.ndarray)):
+             raise TypeError("Each value in the dictionary must be a list or a numpy array.")
+           elements = []
+           x = np.empty((N, self.M), dtype=dtype)
+           for feat in df.keys():
+             if all(not (x is None) and not (isinstance(x, float) and np.isnan(x)) for x in df[feat]):
+               id = self.bn.idFromName(feat)
+               x[:, id] = df[feat]
+               if id != self.target:
+                 elements.append(id)
+           # Remove duplicate rows in x and unused columns.
+           mask_cols = [i for i in range(self.M) if i not in elements]
+           _, idx = np.unique(x[:, elements], axis=0, return_index=True)
+           x = x[idx, :]
+           x[:, mask_cols] = 0
+           if with_labels:
+             y = self._labelToPos_df(x, elements)
+           else:
+             y = x
+           contributions = self._shap_ndim(y, sorted(elements))
+
+         except TypeError:
+           # Here we are sure that df is a single instance (a dictionary with one row).
+           x = np.empty(self.M, dtype=dtype)
+           elements = []
+           for feat in df.keys():
+             if not (df[feat] is None):
+               id = self.bn.idFromName(feat)
+               x[id] = df[feat]
+               if id != self.target:
+                 elements.append(id)
+           if with_labels:
+             y = self._labelToPos_row(x, elements)
+           else:
+             y = x
+           contributions = self._shap_1dim(y, sorted(elements))
+
+       else:
+         raise TypeError(
+           "The first element of `data` must be a pandas DataFrame, Series or a dictionary, but got {}".format(type(df))
+         )
+
+     if contributions.ndim == 2:
+       values = {
+         z: {self.feat_names[i]: float(contributions[i, z]) for i in elements} for z in range(contributions.shape[1])
+       }
+       importances = {
+         z: {self.feat_names[i]: abs(float(contributions[i, z])) for i in elements}
+         for z in range(contributions.shape[1])
+       }
+       explanation = Explanation(
+         values,
+         importances,
+         list(self.feat_names[sorted(elements)]),
+         x[sorted(elements)],
+         self.baseline,
+         self.func.__name__,
+         "SHAP",
+       )
+     else:
+       values = {
+         z: {self.feat_names[i]: [float(v) for v in contributions[i, :, z]] for i in elements}
+         for z in range(contributions.shape[2])
+       }
+       mean_abs = np.mean(np.abs(contributions), axis=1)
+       importances = {
+         z: {self.feat_names[i]: abs(float(mean_abs[i, z])) for i in elements} for z in range(contributions.shape[2])
+       }
+       explanation = Explanation(
+         values,
+         importances,
+         list(self.feat_names[sorted(elements)]),
+         y[:, sorted(elements)],
+         self.baseline,
+         self.func.__name__,
+         "SHAP",
+       )
+     return explanation
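
For readers of this diff, a minimal usage sketch of the compute() contract documented above, assuming that the concrete subclasses exported by pyagrum.explain in this wheel (e.g. ConditionalShapValues) keep the (bn, target, logit) constructor of this base class; the toy network is illustrative only:

  import pyagrum as gum
  from pyagrum.explain import ConditionalShapValues  # concrete subclass shipped in this wheel

  bn = gum.fastBN("smoking->cancer<-pollution;cancer->xray")  # illustrative discrete BN

  # Constructor arguments assumed to mirror ShapleyValues.__init__ above.
  expl = ConditionalShapValues(bn, "cancer", False)

  # data=None: compute() draws N samples from the BN itself (see the docstring above).
  exp_sampled = expl.compute(None, N=200)

  # Explicit data: a (DataFrame, with_labels) tuple.
  df = gum.generateSample(bn, 500, with_labels=True)[0]
  exp_from_df = expl.compute((df, True))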
pyagrum/explain/__init__.py
@@ -0,0 +1,81 @@
+ ############################################################################
+ # This file is part of the aGrUM/pyAgrum library. #
+ # #
+ # Copyright (c) 2005-2025 by #
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
+ # - Christophe GONZALES(_at_AMU) #
+ # #
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
+ # and/or modify it under the terms of either : #
+ # #
+ # - the GNU Lesser General Public License as published by #
+ # the Free Software Foundation, either version 3 of the License, #
+ # or (at your option) any later version, #
+ # - the MIT license (MIT), #
+ # - or both in dual license, as here. #
+ # #
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
+ # #
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
+ # OTHER DEALINGS IN THE SOFTWARE. #
+ # #
+ # See LICENCES for more details. #
+ # #
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
+ # - Christophe GONZALES(_at_AMU) #
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
+ # #
+ # Contact : info_at_agrum_dot_org #
+ # homepage : http://agrum.gitlab.io #
+ # gitlab : https://gitlab.com/agrumery/agrum #
+ # #
+ ############################################################################
+
+ """
+ This module provides classes for explaining predictions and other computations made by Bayesian networks.
+ """
+
+ __author__ = "Pierre-Henri Wuillemin"
+ __copyright__ = "(c) 2019-2025 PARIS"
+
+ # Shapley Values
+ from ._ShapConditionalValues import ConditionalShapValues
+ from ._ShapMarginalValues import MarginalShapValues
+ from ._ShapCausalValues import CausalShapValues
+ from ._Explanation import Explanation
+
+ # Shall Values
+ from ._ShallConditionalValues import ConditionalShallValues
+ from ._ShallMarginalValues import MarginalShallValues
+ from ._ShallCausalValues import CausalShallValues
+
+ # Independence List For Pairs
+ from ._explIndependenceListForPairs import independenceListForPairs
+
+ # Generalized Markov Blanket
+ from ._explGeneralizedMarkovBlanket import generalizedMarkovBlanket
+
+ # Entropy and Mutual Information
+ from ._explInformationGraph import getInformationGraph, getInformation, showInformation
+
+ __all__ = [
+   "ConditionalShapValues",
+   "MarginalShapValues",
+   "CausalShapValues",
+   "ConditionalShallValues",
+   "MarginalShallValues",
+   "CausalShallValues",
+   "Explanation",
+   "independenceListForPairs",
+   "generalizedMarkovBlanket",
+   "getInformationGraph",
+   "getInformation",
+   "showInformation",
+ ]
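
Everything listed in __all__ above is re-exported at the package level, so a consumer of this wheel can import the helpers directly from pyagrum.explain, for instance:

  from pyagrum.explain import (
    ConditionalShapValues,
    MarginalShapValues,
    CausalShapValues,
    generalizedMarkovBlanket,
    independenceListForPairs,
  )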
pyagrum/explain/_explGeneralizedMarkovBlanket.py
@@ -0,0 +1,152 @@
+ ############################################################################
+ # This file is part of the aGrUM/pyAgrum library. #
+ # #
+ # Copyright (c) 2005-2025 by #
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
+ # - Christophe GONZALES(_at_AMU) #
+ # #
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
+ # and/or modify it under the terms of either : #
+ # #
+ # - the GNU Lesser General Public License as published by #
+ # the Free Software Foundation, either version 3 of the License, #
+ # or (at your option) any later version, #
+ # - the MIT license (MIT), #
+ # - or both in dual license, as here. #
+ # #
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
+ # #
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
+ # OTHER DEALINGS IN THE SOFTWARE. #
+ # #
+ # See LICENCES for more details. #
+ # #
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
+ # - Christophe GONZALES(_at_AMU) #
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
+ # #
+ # Contact : info_at_agrum_dot_org #
+ # homepage : http://agrum.gitlab.io #
+ # gitlab : https://gitlab.com/agrumery/agrum #
+ # #
+ ############################################################################
+
+ import pyagrum.lib._colors as gumcols
+ import matplotlib.pyplot as plt
+ import pydot as dot
+
+
+ def _buildMB(model, x: int, k: int = 1):
+   """
+   Build the nodes and arcs of the Markov blanket (of order k) of node x
+
+   Parameters
+   ----------
+   model: pyagrum.DirectedGraphicalModel
+     i.e. a class with methods parents, children, variable(i), idFromName(name)
+   x : int
+     the nodeId of the node for the Markov blanket
+   k: int
+     the order of the Markov blanket. If k=2, build the MarkovBlanket(MarkovBlanket())
+
+   Returns
+   -------
+   (nodes, arcs, depth) : the set of nodes, the set of arcs of the Markov blanket and a dict[int, int] that gives the MB-depth of each node in nodes.
+   """
+   nodes = {x}
+   arcs = set()
+   depth = dict()
+
+   def _internal_build_markov_blanket(bn, x: int, k: int):
+     nodes.add(x)
+     depth[x] = k
+     if k == 0:
+       return
+     for y in bn.parents(x):
+       visit(y, k - 1)
+       arcs.add((y, x))
+     for y in bn.children(x):
+       visit(y, k - 1)
+       arcs.add((x, y))
+       for z in bn.parents(y):
+         visit(z, k - 1)
+         arcs.add((z, y))
+
+   def visit(x, k):
+     if x in nodes and depth[x] >= k:
+       return
+     _internal_build_markov_blanket(model, x, k)
+
+   _internal_build_markov_blanket(model, x, k)
+   return nodes, arcs, depth
+
+
+ def generalizedMarkovBlanket(bn, var: int | str, k: int = 1, cmapNode=None):
+   """
+   Build a pydot.Dot representation of the nested Markov blankets (of order k) of node var
+
+   Warnings
+   --------
+   It is assumed that k<=8. If not, everything still works, but the colorscale is rescaled to accommodate more colors.
+
+   Parameters
+   ----------
+   bn: pyagrum.DirectedGraphicalModel
+     i.e. a class with methods parents, children, variable(i), idFromName(name)
+   var : str|int
+     the name or nodeId of the node for the Markov blanket
+   k: int
+     the order of the Markov blanket. If k=2, build the MarkovBlanket(MarkovBlanket())
+   cmapNode: matplotlib.ColorMap
+     the colormap used (if not, inferno is used)
+
+   Returns
+   -------
+   pydot.Dot object
+   """
+   if cmapNode is None:
+     cmapNode = plt.get_cmap("inferno")  # gum.config["notebook", "default_arc_cmap"])
+
+   maxcols = max(
+     8, k
+   )  # It is assumed that k<=8. If not, everything still works, but the colorscale is rescaled to accommodate more colors.
+
+   mb = dot.Dot(f"MB({var},{k})", graph_type="digraph", bgcolor="transparent")
+
+   if isinstance(var, str):
+     nx = bn.idFromName(var)
+   else:
+     nx = var
+   nodes, arcs, visited = _buildMB(bn, nx, k)
+   names = dict()
+
+   for n in nodes:
+     protected_name = f'"{bn.variable(n).name()}"'
+     pnode = dot.Node(protected_name, style="filled")
+     if n == nx:
+       bgcol = "#99FF99"
+       fgcol = "black"
+     else:
+       bgcol = gumcols.proba2bgcolor(1 - (k - visited[n]) / maxcols, cmapNode)
+       fgcol = gumcols.proba2fgcolor(1 - (k - visited[n]) / maxcols, cmapNode)
+     pnode.set_fillcolor(bgcol)
+     pnode.set_fontcolor(fgcol)
+     mb.add_node(pnode)
+     names[n] = protected_name
+   for n in nodes:
+     for u in bn.parents(n).intersection(nodes):
+       edge = dot.Edge(names[u], names[n])
+       if (u, n) in arcs:
+         edge.set_color("black")
+       else:
+         edge.set_color("#DDDDDD")
+       mb.add_edge(edge)
+
+   return mb
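
A short sketch of how generalizedMarkovBlanket() above could be used; the network is illustrative, and write_png is a generic pydot.Dot output method rather than part of this module:

  import pyagrum as gum
  from pyagrum.explain import generalizedMarkovBlanket

  bn = gum.fastBN("a->b->c->d;b->e->d;e->f;g->c")  # illustrative structure
  g = generalizedMarkovBlanket(bn, "c", k=2)       # nested Markov blankets up to order 2
  g.write_png("mb_c_k2.png")                       # g is a pydot.Dot, so any pydot writer works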
pyagrum/explain/_explIndependenceListForPairs.py
@@ -0,0 +1,146 @@
+ ############################################################################
+ # This file is part of the aGrUM/pyAgrum library. #
+ # #
+ # Copyright (c) 2005-2025 by #
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
+ # - Christophe GONZALES(_at_AMU) #
+ # #
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
+ # and/or modify it under the terms of either : #
+ # #
+ # - the GNU Lesser General Public License as published by #
+ # the Free Software Foundation, either version 3 of the License, #
+ # or (at your option) any later version, #
+ # - the MIT license (MIT), #
+ # - or both in dual license, as here. #
+ # #
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
+ # #
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
+ # OTHER DEALINGS IN THE SOFTWARE. #
+ # #
+ # See LICENCES for more details. #
+ # #
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
+ # - Christophe GONZALES(_at_AMU) #
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
+ # #
+ # Contact : info_at_agrum_dot_org #
+ # homepage : http://agrum.gitlab.io #
+ # gitlab : https://gitlab.com/agrumery/agrum #
+ # #
+ ############################################################################
+
+ import pyagrum as gum
+ import pylab
+ import matplotlib as mpl
+ import itertools
+
+
+ def _independenceListForPairs(bn, target=None):
+   """
+   returns a list of triples `(i,j,k)` for each non-arc `(i,j)` such that `i` is independent of `j` given `k`.
+
+   Parameters
+   ----------
+   bn: gum.BayesNet
+     the Bayesian Network
+
+   target: (optional) str or int
+     the name or id of the target variable. If a target is given, only independencies given a subset of the Markov blanket of the target are tested.
+
+   Returns
+   -------
+   List[(str,str,List[str])]
+     A list of independencies found in the structure of the BN.
+   """
+
+   def powerset(iterable):
+     xs = list(iterable)
+     # note we return an iterator rather than a list
+     return itertools.chain.from_iterable(itertools.combinations(xs, n) for n in range(len(xs) + 1))
+
+   # testing every d-separation
+   l = []
+   nams = sorted(bn.names())
+   if target is None:
+     firstnams = nams.copy()
+     indepnodes = bn.names()
+   else:
+     indepnodes = {bn.variable(i).name() for i in gum.MarkovBlanket(bn, target).nodes()}
+     if isinstance(target, str):
+       firstnams = [target]
+     else:
+       firstnams = [bn.variable(target).name()]
+
+   for i in firstnams:
+     nams.remove(i)
+     for j in nams:
+       if not (bn.existsArc(i, j) or bn.existsArc(j, i)):
+         for k in powerset(sorted(indepnodes - {i, j})):
+           if bn.isIndependent(i, j, k):
+             l.append((i, j, tuple(k)))
+             break
+   return l
+
+
+ def independenceListForPairs(bn, filename, target=None, plot=True, alphabetic=False):
+   """
+   Get the p-values of the chi2 test of an (as simple as possible) independence proposition for every non-arc.
+
+   Parameters
+   ----------
+   bn : gum.BayesNet
+     the Bayesian network
+
+   filename : str
+     the name of the csv database
+
+   alphabetic : bool
+     if True, the list is alphabetically sorted else it is sorted by the p-value
+
+   target: (optional) str or int
+     the name or id of the target variable
+
+   plot : bool
+     if True, plot the result
+
+   Returns
+   -------
+     the dictionary of p-values, keyed by the triples (i, j, conditioning set)
+   """
+
+   learner = gum.BNLearner(filename, bn)
+   vals = {}
+   for indep in _independenceListForPairs(bn, target):
+     vals[indep] = learner.chi2(*indep)[1]
+
+   if plot:
+     plotvals = dict()
+     for indep in vals:
+       key = "$" + indep[0] + " \\perp " + indep[1]
+       if len(indep[2]) > 0:
+         key += " \\mid " + ",".join(indep[2])
+       key += "$"
+       plotvals[key] = vals[indep]
+
+     if not alphabetic:
+       sortedkeys = sorted(plotvals, key=plotvals.__getitem__, reverse=False)
+     else:
+       sortedkeys = list(plotvals.keys())
+
+     fig = pylab.figure(figsize=(10, 1 + 0.25 * len(plotvals)))
+     ax = fig.add_subplot(1, 1, 1)
+     ax.plot([plotvals[k] for k in sortedkeys], sortedkeys, "o")
+     ax.grid(True)
+     ax.vlines(x=0.05, ymin=-0.5, ymax=len(vals) - 0.5, colors="purple")
+     ax.add_patch(mpl.patches.Rectangle((0, -0.5), 0.05, len(vals), color="yellow"))
+
+   return vals
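
A usage sketch for independenceListForPairs() above; the sampled CSV database is illustrative and is written with plain pandas rather than any pyAgrum-specific export:

  import pyagrum as gum
  from pyagrum.explain import independenceListForPairs

  bn = gum.fastBN("a->b->c;a->d->c;c->e")                  # illustrative network
  df = gum.generateSample(bn, 5000, with_labels=True)[0]   # database consistent with bn
  df.to_csv("sample.csv", index=False)

  # p-values of the chi2 tests, keyed by the (i, j, conditioning set) triples built above
  pvals = independenceListForPairs(bn, "sample.csv", plot=False)
  for (i, j, cond), p in sorted(pvals.items(), key=lambda kv: kv[1]):
    print(f"{i} indep. {j} given {cond}: p-value = {p:.4f}")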