pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +172 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
@@ -0,0 +1,264 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ import pyagrum as gum
42
+ import pyagrum.lib._colors as gumcols
43
+ from pyagrum.lib.bn2graph import BN2dot
44
+
45
+ # Matplotlib
46
+ import matplotlib as mpl
47
+
48
+ # GL
49
+ import warnings
50
+
51
# Segment data of the information colormap. Each channel lists anchors
# (x, value just below x, value just above x) on [0, 1]: the map runs from a
# dark red (0.3, 0, 0) at 0 to a pale blue-violet (0.6, 0.6, 1.0) at 1.
_cdict = dict(
  red=((0.0, 0.1, 0.3), (1.0, 0.6, 1.0)),
  green=((0.0, 0.0, 0.0), (1.0, 0.6, 0.8)),
  blue=((0.0, 0.0, 0.0), (1.0, 1, 0.8)),
)
# 256-entry colormap, default for both node (entropy) and arc (mutual information) colors.
_INFOcmap = mpl.colors.LinearSegmentedColormap("my_colormap", segmentdata=_cdict, N=256)
57
+
58
+
59
+ def _normalizeVals(vals, hilightExtrema=False):
60
+ """
61
+ normalisation if vals is not a proba (max>1)
62
+ """
63
+ ma = float(max(vals.values()))
64
+ mi = float(min(vals.values()))
65
+ if ma == mi:
66
+ return None
67
+
68
+ if not hilightExtrema:
69
+ vmi = 0.01
70
+ vma = 0.99
71
+ else:
72
+ vmi = 0
73
+ vma = 1
74
+
75
+ res = {name: vmi + (val - mi) * (vma - vmi) / (ma - mi) for name, val in vals.items()}
76
+ return res
77
+
78
+
79
def getInformationGraph(bn, evs=None, size=None, cmap=_INFOcmap, withMinMax=False):
  """
  Create a dot representation of the information graph for this BN

  Nodes are colored by their entropy (observed nodes are not valued) and arcs are
  weighted by the mutual information between their two ends.

  Parameters
  ----------
  bn: gum.BayesNet
    the BN
  evs : Dict[str,str|int|List[float]]
    map of evidence
  size: str|int
    size of the graph
  cmap: matplotlib.colors.Colormap
    color map
  withMinMax: bool
    min and max in the return values ?

  Returns
  -------
  dot.Dot | Tuple[dot.Dot,float,float,float,float]
    graph as a dot representation and if asked, min_entropy, max_entropy,
    min_mutual_information, max_mutual_information. When no node is left
    unobserved (resp. the BN has no arc), the corresponding min and max are 0.0.
  """
  if size is None:
    size = gum.config["notebook", "default_graph_size"]

  if evs is None:
    evs = {}

  ie = gum.LazyPropagation(bn)
  ie.setEvidence(evs)
  ie.makeInference()

  # nodes with hard or soft evidence receive no entropy value
  idEvs = ie.hardEvidenceNodes() | ie.softEvidenceNodes()

  nodevals = {}
  for n in bn.nodes():
    if n not in idEvs:
      v = ie.H(n)
      if v != v:  # is NaN
        warnings.warn(f"For {bn.variable(n).name()}, entropy is NaN.")
        v = 0
      nodevals[bn.variable(n).name()] = v

  arcvals = {}
  for x, y in bn.arcs():
    v = ie.jointMutualInformation({x, y})
    if v != v:  # is NaN
      warnings.warn(f"For {bn.variable(x).name()}->{bn.variable(y).name()}, mutual information is Nan.")
      v = 0
    arcvals[(x, y)] = v

  gr = BN2dot(
    bn,
    size,
    # an empty nodevals (all nodes observed) cannot be normalized
    nodeColor=_normalizeVals(nodevals, hilightExtrema=False) if nodevals else None,
    arcWidth=arcvals,
    cmapNode=cmap,
    cmapArc=cmap,
    showMsg=nodevals,
  )

  if withMinMax:
    # `default` avoids a ValueError for a BN without arcs or with every node observed
    mi_node = min(nodevals.values(), default=0.0)
    ma_node = max(nodevals.values(), default=0.0)
    mi_arc = min(arcvals.values(), default=0.0)
    ma_arc = max(arcvals.values(), default=0.0)
    return gr, mi_node, ma_node, mi_arc, ma_arc

  return gr
148
+
149
+
150
def _reprInformation(bn, evs=None, size=None, cmap=_INFOcmap, asString=False):
  """
  repr a bn annotated with results from inference : Information and mutual information

  The information graph (SVG) is followed by a PNG colorbar legend for the entropy range.

  Parameters
  ----------
  bn: pyagrum.BayesNet
    the model
  evs: Dict[str|int,str|int|List[float]]
    the observations
  size: int|str
    size of the rendered graph
  cmap: matplotlib.colors.Colormap
    the cmap
  asString: bool
    returns the string or display the HTML

  Returns
  -------
  str|None
    return the HTML string or directly display it.
  """
  import IPython.display
  import IPython.core.pylabtools
  from base64 import encodebytes
  from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
  # explicit imports: `import matplotlib` alone does not guarantee that the
  # `figure`/`colorbar` submodules are available as attributes of `mpl`
  from matplotlib.colorbar import ColorbarBase
  from matplotlib.figure import Figure

  if size is None:
    size = gum.config["notebook", "default_graph_size"]

  if evs is None:
    evs = {}

  gr, mi, ma, _, _ = getInformationGraph(bn, evs, size, cmap, withMinMax=True)
  gumcols.prepareDot(gr, size=size)

  # dynamic member makes pylink unhappy
  # pylint: disable=no-member
  gsvg = IPython.display.SVG(gr.create_svg(encoding="utf-8"))
  # width attribute of the <svg> element, written by graphviz as width="...pt"
  svg_width = float(gsvg.data.split("width=")[1].split('"')[1].split("pt")[0])
  # NOTE(review): the width is in pt (1/72 in) but is scaled by figure.dpi, so the legend
  # width only approximates the graph width; it is at least 5 inches anyway.
  # mpl.rcParams is the same object as pyplot.rcParams, without requiring pyplot.
  width = max(svg_width / mpl.rcParams["figure.dpi"], 5)

  fig = Figure(figsize=(width, 1))
  fig.patch.set_alpha(0)
  canvas = fc(fig)
  ax1 = fig.add_axes([0.05, 0.80, 0.9, 0.15])
  norm = mpl.colors.Normalize(vmin=mi, vmax=ma)
  cb1 = ColorbarBase(ax1, cmap=cmap, norm=norm, orientation="horizontal")
  cb1.set_label("Entropy")
  cb1.ax.text(mi, -2, f"{mi:.4f}", ha="left", va="top", color=gumcols.proba2bgcolor(0.01, cmap))
  cb1.ax.text(ma, -2, f"{ma:.4f}", ha="right", va="top", color=gumcols.proba2bgcolor(0.99, cmap))
  png = IPython.core.pylabtools.print_figure(canvas.figure, "png")  # from IPython.core.pylabtools
  png_legend = f"<img style='vertical-align:middle' src='data:image/png;base64,{encodebytes(png).decode('ascii')}'>"

  sss = f"<div align='center'>{gsvg.data}<br/>{png_legend}</div>"

  if asString:
    return sss

  return IPython.display.display(IPython.display.HTML(sss))
213
+
214
+
215
def getInformation(bn, evs=None, size=None, cmap=_INFOcmap) -> str:
  """
  Build the HTML representation of a BN annotated with inference results:
  the entropy of each node and the mutual information along each arc.

  Parameters
  ----------
  bn: pyagrum.BayesNet
    the model
  evs: Dict[str|int,str|int|List[float]]
    the observations (none when omitted)
  size: int|str
    size of the rendered graph (notebook default_graph_size when omitted)
  cmap: matplotlib.colors.Colormap
    the cmap

  Returns
  -------
  str
    the HTML string
  """
  graph_size = gum.config["notebook", "default_graph_size"] if size is None else size
  evidence = {} if evs is None else evs
  return _reprInformation(bn, evidence, graph_size, cmap, asString=True)
242
+
243
+
244
def showInformation(bn, evs=None, size=None, cmap=_INFOcmap):
  """
  Display (as HTML) a BN annotated with inference results: the entropy of each
  node and the mutual information along each arc.

  Parameters
  ----------
  bn: pyagrum.BayesNet
    the model
  evs: Dict[str|int,str|int|List[float]]
    the observations (none when omitted)
  size: int|str
    size of the rendered graph (notebook default_graph_size when omitted)
  cmap: matplotlib.colors.Colormap
    the cmap
  """
  evidence = {} if evs is None else evs
  graph_size = gum.config["notebook", "default_graph_size"] if size is None else size
  return _reprInformation(bn, evidence, graph_size, cmap, asString=False)
@@ -0,0 +1,54 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
__author__ = "Pierre-Henri Wuillemin"
__copyright__ = "(c) 2019-2025 PARIS"

# Public plotting helpers of this subpackage, re-exported from their private modules.
from ._bar import bar
from ._beeswarm import beeswarm
from ._waterfall import waterfall
from ._showShapValues import showShapValues

__all__ = ["bar", "beeswarm", "waterfall", "showShapValues"]
@@ -0,0 +1,142 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ import pyagrum as gum
42
+ from pyagrum.explain._Explanation import Explanation
43
+
44
+ import numpy as np
45
+
46
+ import matplotlib.pyplot as plt
47
+ from matplotlib.colors import LinearSegmentedColormap, to_rgb
48
+ from matplotlib.patches import Patch
49
+
50
+
51
def bar(explanation: Explanation, y: int = None, ax: plt.Axes = None, percentage: bool = False) -> plt.Axes:
  """
  Plots a horizontal bar chart of the mean absolute SHAP/SHALL values for each feature in the explanation.

  Parameters
  ----------
  explanation : Explanation
    The explanation object containing the SHAP/SHALL values.
  y : int, optional
    If the values type of the explanation is SHALL, then y is ignored.
    Else it is the class for which to plot the SHAP values (default is None, which plots stacked bars for all classes).
  ax : plt.Axes, optional
    The matplotlib Axes object to plot on (default is None, which creates a new figure).
  percentage: bool
    if True, the importances are shown as percentages of their total (over all classes for stacked bars).

  Returns
  -------
  plt.Axes
    The Axes the chart was drawn on.

  Raises
  ------
  TypeError : If `explanation` is not an Explanation object or if `y` is not an integer or None.
  IndexError : If `y` is an integer but out of bounds for the explanation keys.
  ValueError : If the values type of the explanation is neither SHAP nor SHALL.
  """

  if not isinstance(explanation, Explanation):
    raise TypeError(f"`explanation` must be an Explanation object but got {type(explanation)}")

  # Determine if The explanation object is a SHALL or SHAP explanation
  if explanation.values_type == "SHAP":
    # numpy integers are accepted too (e.g. a class index taken from an array)
    if y is not None and not isinstance(y, (int, np.integer)):
      raise TypeError(f"`y` must be either a positive integer or None, but got {type(y)}")
    if y is not None and (y < min(explanation.keys()) or y > max(explanation.keys())):
      raise IndexError(f"Target index y={y} is out of bounds; expected 0 <= y < {max(explanation.keys()) + 1}.")
  elif explanation.values_type == "SHALL":
    # We force y to be an integer, so we can use the same code after for both explanations
    y = 0
  else:
    raise ValueError(f"Wrong values type, expected SHAP/SHALL but got {explanation.values_type}")

  if ax is None:
    _, ax = plt.subplots(figsize=(6, 4))

  color0 = gum.config["notebook", "tensor_color_0"]

  if y is not None:
    importances = explanation.importances[y] if explanation.values_type == "SHAP" else explanation.importances
    columns = sorted(importances.keys(), key=importances.get)
    values = [importances[feat] for feat in columns]
    if percentage:
      total = sum(values)
      # a zero total means all importances are 0: nothing to rescale
      if total != 0:
        values = [(v / total) * 100 for v in values]
    ax.barh(columns, values, color=color0, height=0.5, alpha=0.8)
  else:
    classes = sorted(explanation.keys())
    cmap = LinearSegmentedColormap.from_list(
      "class_cmap", [to_rgb(color0), to_rgb(gum.config["notebook", "tensor_color_1"])]
    )
    # spread the classes over the colormap; a single class gets the first color
    n_classes = len(classes)
    colors = [cmap(i / (n_classes - 1)) if n_classes > 1 else cmap(0.0) for i in range(n_classes)]

    n_features = len(explanation.feature_names)
    values = np.array([[explanation.importances[z][feat] for feat in explanation.feature_names] for z in classes])
    # Sort bars by the total importance over all classes
    indices = np.argsort(np.sum(values, axis=0))
    values = values[:, indices]
    features = [explanation.feature_names[i] for i in indices]

    if percentage:
      # percentages of the grand total, so that all the stacked segments sum to 100
      total = values.sum()
      if total != 0:
        values = values / total * 100

    bottom = np.zeros(n_features)
    for i, cls in enumerate(classes):
      contribs = values[i]
      ax.barh(features, contribs, height=0.5, left=bottom, color=colors[i], label=f"class {cls}", alpha=0.8)
      bottom += contribs

    # label the legend with the actual class keys (not their rank)
    legend_elements = [
      Patch(facecolor=colors[i], edgecolor="black", label=f"Class {cls}") for i, cls in enumerate(classes)
    ]
    ax.legend(loc="lower right", handles=legend_elements, title="Classes")

  ax.set_title("Feature Importance", fontsize=16)

  msg = " in %" if percentage else ""
  ax.set_xlabel(f"mean(|{explanation.values_type} value|){msg}", fontsize=12)
  ax.set_ylabel("Features", fontsize=12)
  ax.tick_params(axis="x", labelsize=10)
  ax.tick_params(axis="y", labelsize=10)

  # Removing spines
  ax.grid(axis="x", linestyle=":", alpha=0.6)
  ax.grid(axis="y", linestyle=":", alpha=0.3)
  ax.spines["top"].set_visible(False)
  ax.spines["bottom"].set_visible(False)
  ax.spines["left"].set_visible(False)
  ax.spines["right"].set_visible(False)
  ax.figure.set_facecolor("white")

  return ax
@@ -0,0 +1,174 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ import pyagrum as gum
42
+ from pyagrum.explain._Explanation import Explanation
43
+
44
+ import numpy as np
45
+
46
+ import matplotlib.pyplot as plt
47
+ import matplotlib.cm as cm
48
+ from matplotlib import colors
49
+
50
+
51
def beeswarm(
  explanation: Explanation, y: int = 1, max_display: int = 20, color_bar: bool = True, ax=None, sort: bool = True
):
  """
  Plots a beeswarm plot of the Shapley values for a given target class.

  Parameters
  ----------
  explanation : Explanation
    The explanation object containing the SHAP/SHALL values.
  y : int
    If the values type of the explanation is SHALL, then y is ignored.
    Else it is the class for which to plot the SHAP values.
  max_display : int, optional
    The maximum number of features to display in the beeswarm plot (default is 20).
  color_bar : bool, optional
    If True, adds a color bar to the plot (default is True).
  ax : plt.Axes, optional
    The matplotlib Axes object to plot on (default is None, which creates a new figure).
  sort : bool, optional
    If True, the `max_display` most important features are plotted, sorted by decreasing importance;
    otherwise the first `max_display` features are plotted in their original order (default is True).

  Returns
  -------
  plt.Axes
    The Axes the plot was drawn on.

  Raises
  ------
  TypeError
    If `explanation` is not an Explanation object, if `y` is not an integer or if the explanation is not global (i.e., does not contain lists of contributions for each feature).
  IndexError
    If `y` is out of bounds for the explanation keys.
  ValueError
    If the values type of the explanation is neither SHAP nor SHALL.
  """
  # Check parameters
  if not isinstance(explanation, Explanation):
    raise TypeError("`explanation` must be an Explanation object but got {}".format(type(explanation)))

  # Determine if The explanation object is a SHALL or SHAP explanation
  if explanation.values_type == "SHAP":
    # numpy integers are accepted too (consistent with `bar`)
    if not isinstance(y, (int, np.integer)):
      raise TypeError("`y` must be an integer but got {}".format(type(y)))
    if y < min(explanation.keys()) or y > max(explanation.keys()):
      raise IndexError(f"Target index y={y} is out of bounds; expected 0 <= y < {max(explanation.keys()) + 1}.")
    contributions = explanation[y]
    importances = explanation.importances[y]
  elif explanation.values_type == "SHALL":
    contributions = explanation
    importances = explanation.importances
  else:
    raise ValueError(f"Wrong values type, expected SHAP/SHALL but got {explanation.values_type}")

  feature_names = explanation.feature_names
  # a global explanation stores a list of contributions per feature
  if not isinstance(next(iter(contributions.values()), None), list):
    raise TypeError("For beeswarm plot, explanation must be global.")
  values = np.array([contributions[k] for k in feature_names]).T
  features = explanation.data

  # Create the figure and axis if not provided
  if ax is None:
    _, ax = plt.subplots()

  # Prepare the y-axis positions: one row per displayed feature, top to bottom
  n_display = min(max_display, len(feature_names))
  y_positions = np.arange(n_display, 0, -1)

  # Plot the beeswarm
  ax.plot([0, 0], [y_positions[-1] - 0.25, y_positions[0] + 0.25], linestyle="--", color="gray")
  color1 = gum.config["notebook", "tensor_color_0"]
  color2 = gum.config["notebook", "tensor_color_1"]
  cmap = colors.LinearSegmentedColormap.from_list("custom_red_green", [color1, color2])

  if sort:
    indices = [feature_names.index(feat) for feat in sorted(importances, key=importances.get, reverse=True)]
  else:
    indices = list(range(values.shape[1]))
  # only n_display features have a row in y_positions
  indices = indices[:n_display]

  for k, j in enumerate(indices):
    base = y_positions[k]
    sequence = np.arange(values.shape[0])
    np.random.shuffle(sequence)

    shapes = values[sequence, j]
    rounded_x = np.round(shapes, 2)
    (unique, counts) = np.unique(rounded_x, return_counts=True)
    density_map = dict(zip(unique, counts))
    densities = np.array([density_map[val] for val in rounded_x])

    # vertical jitter grows with the number of points sharing the same rounded value
    sigmas = (densities / np.max(densities)) * 0.1
    ords = np.random.normal(loc=base, scale=sigmas)

    # points are colored by the feature value (scatter scales the colors per feature)
    vals = features[sequence, j]

    ax.scatter(
      shapes,
      ords,
      c=vals,
      cmap=cmap,
      s=7,
    )

  ax.set_yticks(y_positions)
  ax.set_yticklabels([feature_names[i] for i in indices])
  if color_bar:
    norm = colors.Normalize(vmin=0.0, vmax=1.0)
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # attach the colorbar to the figure owning `ax` (plt.colorbar would use the current figure)
    cbar = ax.figure.colorbar(sm, ax=ax)
    cbar.set_label("Feature value")
    cbar.set_ticks([0, 1])
    cbar.set_ticklabels(["Low", "High"])

  ax.set_ylim(y_positions[-1] - 0.5, y_positions[0] + 0.5)
  ax.set_xlabel("Impact on model Output", fontsize=12)
  ax.set_ylabel("Features", fontsize=12)
  ax.set_title(f"{explanation.values_type} value (Impact on model Output)", fontsize=16)

  # Setting the style
  ax.grid(axis="x", linestyle=":", alpha=0.5)
  ax.grid(axis="y", linestyle=":", alpha=0.5)
  ax.spines["top"].set_visible(False)
  ax.spines["bottom"].set_visible(False)
  ax.spines["left"].set_visible(False)
  ax.spines["right"].set_visible(False)
  ax.figure.set_facecolor("white")

  return ax