pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +171 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
import pyagrum as gum
|
|
42
|
+
import pyagrum.lib._colors as gumcols
|
|
43
|
+
from pyagrum.lib.bn2graph import BN2dot
|
|
44
|
+
|
|
45
|
+
# Matplotlib
|
|
46
|
+
import matplotlib as mpl
|
|
47
|
+
|
|
48
|
+
# GL
|
|
49
|
+
import warnings
|
|
50
|
+
|
|
51
|
+
# Segment data of the default colormap used to render entropy / mutual information:
# for each channel, rows in matplotlib's LinearSegmentedColormap (x, y0, y1) format.
_cdict = dict(
    red=((0.0, 0.1, 0.3), (1.0, 0.6, 1.0)),
    green=((0.0, 0.0, 0.0), (1.0, 0.6, 0.8)),
    blue=((0.0, 0.0, 0.0), (1.0, 1, 0.8)),
)
# default colormap (256 levels) of the information graph
_INFOcmap = mpl.colors.LinearSegmentedColormap("my_colormap", _cdict, N=256)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _normalizeVals(vals, hilightExtrema=False):
|
|
60
|
+
"""
|
|
61
|
+
normalisation if vals is not a proba (max>1)
|
|
62
|
+
"""
|
|
63
|
+
ma = float(max(vals.values()))
|
|
64
|
+
mi = float(min(vals.values()))
|
|
65
|
+
if ma == mi:
|
|
66
|
+
return None
|
|
67
|
+
|
|
68
|
+
if not hilightExtrema:
|
|
69
|
+
vmi = 0.01
|
|
70
|
+
vma = 0.99
|
|
71
|
+
else:
|
|
72
|
+
vmi = 0
|
|
73
|
+
vma = 1
|
|
74
|
+
|
|
75
|
+
res = {name: vmi + (val - mi) * (vma - vmi) / (ma - mi) for name, val in vals.items()}
|
|
76
|
+
return res
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def getInformationGraph(bn, evs=None, size=None, cmap=_INFOcmap, withMinMax=False):
    """
    Create a dot representation of the information graph for this BN

    Every non-observed node is coloured according to its entropy given the evidence (the entropy
    is also passed as the node message), and every arc is drawn with a width given by the mutual
    information between its two ends.

    Parameters
    ----------
    bn: gum.BayesNet
      the BN
    evs : Dict[str,str|int|List[float]]
      map of evidence (no evidence if None)
    size: str|int
      size of the graph (``gum.config["notebook", "default_graph_size"]`` if None)
    cmap: matplotlib.colors.Colormap
      color map (used for both nodes and arcs)
    withMinMax: bool
      min and max in the return values ?

    Returns
    -------
    dot.Dot | Tuple[dot.Dot,float,float,float,float]
      graph as a dot representation and if asked, min_information_value, max_information_value,
      min_mutual_information_value, max_mutual_information_value. When there is no non-observed
      node (resp. no arc), the corresponding min and max are 0.
    """
    if size is None:
        size = gum.config["notebook", "default_graph_size"]

    if evs is None:
        evs = {}

    ie = gum.LazyPropagation(bn)
    ie.setEvidence(evs)
    ie.makeInference()

    idEvs = ie.hardEvidenceNodes() | ie.softEvidenceNodes()

    # entropy of every non-observed node (a NaN entropy is reported and replaced by 0)
    nodevals = dict()
    for n in bn.nodes():
        if n not in idEvs:
            v = ie.H(n)
            if v != v:  # is NaN
                warnings.warn(f"For {bn.variable(n).name()}, entropy is NaN.")
                v = 0
            nodevals[bn.variable(n).name()] = v

    # mutual information between the two ends of every arc (NaN is reported and replaced by 0)
    arcvals = dict()
    for x, y in bn.arcs():
        v = ie.jointMutualInformation({x, y})
        if v != v:  # is NaN
            warnings.warn(f"For {bn.variable(x).name()}->{bn.variable(y).name()}, mutual information is Nan.")
            v = 0
        arcvals[(x, y)] = v

    # when every node is observed there is nothing to colour: pass None (as for a constant entropy)
    node_colors = _normalizeVals(nodevals, hilightExtrema=False) if nodevals else None

    gr = BN2dot(
        bn,
        size,
        nodeColor=node_colors,
        arcWidth=arcvals,
        cmapNode=cmap,
        cmapArc=cmap,
        showMsg=nodevals,
    )

    if withMinMax:
        # `default` avoids a ValueError for a BN without arcs or with every node observed
        mi_node = min(nodevals.values(), default=0)
        ma_node = max(nodevals.values(), default=0)
        mi_arc = min(arcvals.values(), default=0)
        ma_arc = max(arcvals.values(), default=0)
        return gr, mi_node, ma_node, mi_arc, ma_arc
    else:
        return gr
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _reprInformation(bn, evs=None, size=None, cmap=_INFOcmap, asString=False):
    """
    repr a bn annotated with results from inference : Information and mutual information

    The graph is built by `getInformationGraph`. Below it, a PNG colour bar labelled "Entropy"
    shows the scale between the smallest and the largest entropy of the non-observed nodes.

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    evs: Dict[str|int,str|int|List[float]]
      the observations (no evidence if None)
    size: int|str
      size of the rendered graph (``gum.config["notebook", "default_graph_size"]`` if None)
    cmap: matplotlib.colors.Colormap
      the cmap
    asString: bool
      returns the string or display the HTML

    Returns
    -------
    str|None
      return the HTML string or directly display it.
    """
    import IPython.display
    import IPython.core.pylabtools
    from base64 import encodebytes
    from matplotlib.backends.backend_agg import FigureCanvasAgg as fc

    # explicit submodule imports: `import matplotlib as mpl` does not guarantee that
    # mpl.figure, mpl.colorbar or mpl.pyplot are loaded
    from matplotlib.colorbar import ColorbarBase
    from matplotlib.colors import Normalize
    from matplotlib.figure import Figure

    if size is None:
        size = gum.config["notebook", "default_graph_size"]

    if evs is None:
        evs = {}

    gr, mi, ma, _, _ = getInformationGraph(bn, evs, size, cmap, withMinMax=True)
    gumcols.prepareDot(gr, size=size)

    # dynamic member makes pylint unhappy
    # pylint: disable=no-member
    gsvg = IPython.display.SVG(gr.create_svg(encoding="utf-8"))

    # width of the rendered SVG (value of its first `width="...pt"` attribute) divided by the
    # figure dpi, used as the width of the legend figure, with a minimum of 5.
    # float() also accepts non-integer widths.
    svg_width = float(gsvg.data.split("width=")[1].split('"')[1].split("pt")[0])
    width = svg_width / mpl.rcParams["figure.dpi"]
    if width < 5:
        width = 5

    fig = Figure(figsize=(width, 1))
    fig.patch.set_alpha(0)
    canvas = fc(fig)
    ax1 = fig.add_axes([0.05, 0.80, 0.9, 0.15])
    norm = Normalize(vmin=mi, vmax=ma)
    cb1 = ColorbarBase(ax1, cmap=cmap, norm=norm, orientation="horizontal")
    cb1.set_label("Entropy")
    cb1.ax.text(mi, -2, f"{mi:.4f}", ha="left", va="top", color=gumcols.proba2bgcolor(0.01, cmap))
    cb1.ax.text(ma, -2, f"{ma:.4f}", ha="right", va="top", color=gumcols.proba2bgcolor(0.99, cmap))
    png = IPython.core.pylabtools.print_figure(canvas.figure, "png")  # from IPython.core.pylabtools
    png_legend = f"<img style='vertical-align:middle' src='data:image/png;base64,{encodebytes(png).decode('ascii')}'>"

    sss = f"<div align='center'>{gsvg.data}<br/>{png_legend}</div>"

    if asString:
        return sss

    return IPython.display.display(IPython.display.HTML(sss))
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def getInformation(bn, evs=None, size=None, cmap=_INFOcmap) -> str:
    """
    Build the HTML representation of a bn annotated with inference results:
    entropy of the nodes and mutual information along the arcs.

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    evs: Dict[str|int,str|int|List[float]]
      the observations (no evidence if None)
    size: int|str
      size of the rendered graph (``gum.config["notebook", "default_graph_size"]`` if None)
    cmap: matplotlib.colors.Colormap
      the colormap

    Returns
    -------
    str
      the HTML string (graph followed by its entropy colour bar)
    """
    graph_size = gum.config["notebook", "default_graph_size"] if size is None else size
    observations = {} if evs is None else evs
    return _reprInformation(bn, observations, graph_size, cmap, asString=True)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def showInformation(bn, evs=None, size=None, cmap=_INFOcmap):
    """
    Display a bn annotated with inference results: entropy of the nodes and
    mutual information along the arcs.

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    evs: Dict[str|int,str|int|List[float]]
      the observations (no evidence if None)
    size: int|str
      size of the rendered graph (``gum.config["notebook", "default_graph_size"]`` if None)
    cmap: matplotlib.colors.Colormap
      the colormap
    """
    observations = evs if evs is not None else {}
    graph_size = size if size is not None else gum.config["notebook", "default_graph_size"]
    return _reprInformation(bn, observations, graph_size, cmap, asString=False)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
# package metadata
__author__, __copyright__ = "Pierre-Henri Wuillemin", "(c) 2019-2025 PARIS"
|
|
43
|
+
|
|
44
|
+
from ._bar import bar
|
|
45
|
+
from ._beeswarm import beeswarm
|
|
46
|
+
from ._waterfall import waterfall
|
|
47
|
+
from ._showShapValues import showShapValues
|
|
48
|
+
|
|
49
|
+
# plotting helpers re-exported from the private submodules of this package
__all__ = ["bar", "beeswarm", "waterfall", "showShapValues"]
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
import pyagrum as gum
|
|
42
|
+
from pyagrum.explain._Explanation import Explanation
|
|
43
|
+
|
|
44
|
+
import numpy as np
|
|
45
|
+
|
|
46
|
+
import matplotlib.pyplot as plt
|
|
47
|
+
from matplotlib.colors import LinearSegmentedColormap, to_rgb
|
|
48
|
+
from matplotlib.patches import Patch
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def bar(explanation: Explanation, y: int = None, ax: plt.Axes = None, percentage: bool = False) -> None:
    """
    Plots a horizontal bar chart of the mean absolute SHAP/SHALL values for each feature in the explanation.

    Parameters
    ----------
    explanation : Explanation
      The explanation object containing the SHAP/SHALL values.
    y : int, optional
      If the values type of the explanation is SHALL, then y is ignored.
      Else it is the class for which to plot the SHAP values (default is None, which plots one
      stacked bar per feature, with one segment per class).
    ax : plt.Axes, optional
      The matplotlib Axes object to plot on (default is None, which creates a new figure).
    percentage : bool
      If True, the importances are shown in percent of the total importance of the plotted class
      (each class separately in the stacked plot). All-zero importances are left at 0.

    Returns
    -------
    None
      The chart is drawn on `ax`.

    Raises
    ------
    TypeError
      If `explanation` is not an Explanation object or if `y` is not an integer or None.
    IndexError
      If `y` is an integer but out of bounds for the explanation keys.
    ValueError
      If the values type of the explanation is neither SHAP nor SHALL.
    """

    if not isinstance(explanation, Explanation):
        raise TypeError(f"`explanation` must be an Explanation object but got {type(explanation)}")

    # Determine if The explanation object is a SHALL or SHAP explanation
    if explanation.values_type == "SHAP":
        if not isinstance(y, int) and y is not None:
            raise TypeError(f"`y` must be either a positive integer or None, but got {type(y)}")
        if isinstance(y, int) and (y < min(explanation.keys()) or y > max(explanation.keys())):
            raise IndexError(f"Target index y={y} is out of bounds; expected 0 <= y < {max(explanation.keys()) + 1}.")
    elif explanation.values_type == "SHALL":
        # We force y to be an integer, so we can use the same code after for both explanations
        y = 0
    else:
        raise ValueError(f"Wrong values type, expected SHAP/SHALL but got {explanation.values_type}")

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))

    if y is not None:
        importances = explanation.importances[y] if explanation.values_type == "SHAP" else explanation.importances
        columns = sorted(importances.keys(), key=importances.get)
        values = [importances[feat] for feat in columns]
        if percentage:
            total = sum(values)
            if total:  # all-zero importances: keep the zeros instead of dividing by 0
                values = [(v / total) * 100 for v in values]
        ax.barh(columns, values, color=gum.config["notebook", "tensor_color_0"], height=0.5, alpha=0.8)
    else:
        classes = sorted(explanation.keys())
        cmap = LinearSegmentedColormap.from_list(
            "class_cmap", [to_rgb(gum.config["notebook", "tensor_color_0"]), to_rgb(gum.config["notebook", "tensor_color_1"])]
        )
        # one colour per class, spread over the colormap; the max() avoids a division by zero
        # when the explanation contains a single class
        denominator = max(len(classes) - 1, 1)
        colors = [cmap(i / denominator) for i in range(len(classes))]

        n_features = len(explanation.feature_names)
        values = np.array([[explanation.importances[z][feat] for feat in explanation.feature_names] for z in classes])
        # Sort bars by their total (all classes) importance
        indices = np.argsort(np.sum(values, axis=0))
        values = values[:, indices]
        features = [explanation.feature_names[i] for i in indices]
        bottom = np.zeros(n_features)

        for i, cls in enumerate(classes):
            contribs = values[i]
            if percentage:
                total = sum(contribs)
                if total:  # all-zero importances: keep the zeros instead of dividing by 0
                    contribs = [(v / total) * 100 for v in contribs]
            ax.barh(features, contribs, height=0.5, left=bottom, color=colors[i], label=f"class {cls}", alpha=0.8)

            bottom += contribs
        # legend labelled with the actual class keys (not their position in the sorted list)
        legend_elements = [
            Patch(facecolor=colors[i], edgecolor="black", label=f"Class {cls}") for i, cls in enumerate(classes)
        ]
        ax.legend(loc="lower right", handles=legend_elements, title="Classes")

    ax.set_title("Feature Importance", fontsize=16)

    msg = " in %" if percentage else ""
    ax.set_xlabel(f"mean(|{explanation.values_type} value|){msg}", fontsize=12)
    ax.set_ylabel("Features", fontsize=12)
    ax.tick_params(axis="x", labelsize=10)
    ax.tick_params(axis="y", labelsize=10)

    # Removing spines
    ax.grid(axis="x", linestyle=":", alpha=0.6)
    ax.grid(axis="y", linestyle=":", alpha=0.3)
    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.figure.set_facecolor("white")
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
import pyagrum as gum
|
|
42
|
+
from pyagrum.explain._Explanation import Explanation
|
|
43
|
+
|
|
44
|
+
import numpy as np
|
|
45
|
+
|
|
46
|
+
import matplotlib.pyplot as plt
|
|
47
|
+
import matplotlib.cm as cm
|
|
48
|
+
from matplotlib import colors
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def beeswarm(
    explanation: Explanation, y: int = 1, max_display: int = 20, color_bar: bool = True, ax=None, sort: bool = True
):
    """
    Plots a beeswarm plot of the Shapley values for a given target class.

    Each displayed feature gets one row; every point of the row is one instance of the data,
    placed horizontally at its SHAP/SHALL value, jittered vertically (more jitter where values
    are dense) and coloured by the value of the feature for that instance.

    Parameters
    ----------
    explanation : Explanation
      The explanation object containing the SHAP/SHALL values.
    y : int
      If the values type of the explanation is SHALL, then y is ignored.
      Else it is the class for which to plot the SHAP values.
    max_display : int, optional
      The maximum number of features to display in the beeswarm plot (default is 20).
    color_bar : bool, optional
      If True, adds a color bar to the plot (default is True).
    ax : plt.Axes, optional
      The matplotlib Axes object to plot on (default is None, which creates a new figure).
    sort : bool, optional
      If True, sorts the features by their importance before plotting and keeps the
      `max_display` most important ones (default is True).

    Raises
    ------
    TypeError
      If `explanation` is not an Explanation object, if `y` is not an integer or if the explanation is not global (i.e., does not contain lists of contributions for each feature).
    IndexError
      If `y` is out of bounds for the explanation keys.
    ValueError
      If the values type of the explanation is neither SHAP nor SHALL.
    """
    # Check parameters
    if not isinstance(explanation, Explanation):
        raise TypeError(f"`explanation` must be an Explanation object but got {type(explanation)}")

    # Determine if The explanation object is a SHALL or SHAP explanation
    if explanation.values_type == "SHAP":
        if not isinstance(y, int):
            raise TypeError(f"`y` must be an integer but got {type(y)}")
        if y < min(explanation.keys()) or y > max(explanation.keys()):
            raise IndexError(f"Target index y={y} is out of bounds; expected 0 <= y < {max(explanation.keys()) + 1}.")
        contributions = explanation[y]
        importances = explanation.importances[y]
    elif explanation.values_type == "SHALL":
        contributions = explanation
        importances = explanation.importances
    else:
        raise ValueError(f"Wrong values type, expected SHAP/SHALL but got {explanation.values_type}")

    feature_names = explanation.feature_names
    if not isinstance(list(contributions.values())[0], list):
        raise TypeError("For beeswarm plot, explanation must be global.")
    # one row per instance, one column per feature (in feature_names order)
    values = np.array([contributions[k] for k in feature_names]).T
    features = explanation.data

    # Create the figure and axis if not provided
    if ax is None:
        _, ax = plt.subplots()

    # Prepare the y-axis positions (top row first)
    n_display = min(max_display, len(feature_names))
    y_positions = np.arange(n_display, 0, -1)

    # Plot the beeswarm
    ax.plot([0, 0], [y_positions[-1] - 0.25, y_positions[0] + 0.25], linestyle="--", color="gray")
    color1 = gum.config["notebook", "tensor_color_0"]
    color2 = gum.config["notebook", "tensor_color_1"]
    cmap = colors.LinearSegmentedColormap.from_list("custom_red_green", [color1, color2])

    if sort:
        ranked = sorted(importances, key=importances.get, reverse=True)
        # only `n_display` rows exist: truncate, otherwise y_positions[k] overflows
        indices = [feature_names.index(feat) for feat in ranked[:n_display]]
    else:
        indices = np.arange(min(max_display, values.shape[1]))

    for k, j in enumerate(indices):
        base = y_positions[k]
        sequence = np.arange(values.shape[0])
        np.random.shuffle(sequence)

        shapes = values[sequence, j]
        # vertical jitter proportional to the local density of (rounded) values
        rounded_x = np.round(shapes, 2)
        unique, counts = np.unique(rounded_x, return_counts=True)
        density_map = dict(zip(unique, counts))
        densities = np.array([density_map[val] for val in rounded_x])

        sigmas = (densities / np.max(densities)) * 0.1
        ords = np.random.normal(loc=base, scale=sigmas)

        vals = features[sequence, j]

        # colours are scaled by scatter between the min and the max of this feature's values
        ax.scatter(
            shapes,
            ords,
            c=vals,
            cmap=cmap,
            s=7,
        )

    ax.set_yticks(y_positions)
    ax.set_yticklabels([feature_names[i] for i in indices])
    if color_bar:
        sm = cm.ScalarMappable(cmap=cmap, norm=colors.Normalize(vmin=0.0, vmax=1.0))
        sm.set_array([])
        # attach to the figure of `ax` (plt.colorbar would go through pyplot's current figure)
        cbar = ax.figure.colorbar(sm, ax=ax)
        cbar.set_label("Feature value")
        cbar.set_ticks([0, 1])
        cbar.set_ticklabels(["Low", "High"])

    ax.set_ylim(y_positions[-1] - 0.5, y_positions[0] + 0.5)
    ax.set_xlabel("Impact on model Output", fontsize=12)
    ax.set_ylabel("Features", fontsize=12)
    ax.set_title(f"{explanation.values_type} value (Impact on model Output)", fontsize=16)

    # Setting the style
    ax.grid(axis="x", linestyle=":", alpha=0.5)
    ax.grid(axis="y", linestyle=":", alpha=0.5)
    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.figure.set_facecolor("white")
|