pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +172 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
import pyagrum as gum
|
|
42
|
+
import pyagrum.lib.notebook as gnb
|
|
43
|
+
from pyagrum.lib.bn2graph import BN2dot
|
|
44
|
+
from pyagrum.explain import Explanation
|
|
45
|
+
|
|
46
|
+
import matplotlib.pyplot as plt
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def showShapValues(bn: gum.BayesNet, expl: Explanation | dict, cmap="plasma", y=1):
    """
    Show the Shap values in the DAG of the BN.

    Each node listed in the importances is colored by its share of the total
    importance (its importance divided by the sum of all importances), mapped
    through the Matplotlib colormap `cmap`.

    Parameters
    ----------
    bn : pyagrum.BayesNet
        The Bayesian network
    expl : Explanation | dict[str, float]
        The Shap values of each variable. For an Explanation, ``expl.importances[y]`` is used.
    cmap : str
        Name of the Matplotlib colormap used for coloring the nodes.
    y : int
        The target class for which the Shap values are computed (default is 1).
        y is ignored if `expl` is a dict.

    Raises
    ------
    TypeError
        If bn is not a gum.BayesNet, if expl is neither an Explanation nor a dict,
        or if expl is an Explanation and y is not an integer.
    IndexError
        If expl is an Explanation and y is outside the valid class range.
    """
    if not isinstance(bn, gum.BayesNet):
        raise TypeError(f"The parameter bn must be a gum.BayesNet but got {type(bn)}")

    if isinstance(expl, Explanation):
        if not isinstance(y, int):
            raise TypeError("`y`must be an integer but got {}".format(y))
        lowest, highest = min(expl.keys()), max(expl.keys())
        if y < lowest or y > highest:
            raise IndexError(f"Target index y={y} is out of bounds; expected {lowest} <= y < {highest + 1}.")
        importances = expl.importances[y]
    elif isinstance(expl, dict):
        importances = expl
    else:
        raise TypeError(f"The parameter expl must be either an Explanation object or a dict but got {type(expl)}")

    # Compute the normalizing total once, not once per node.
    total = sum(importances.values())
    if total == 0:
        # Every importance is zero (or there are none): avoid dividing by zero
        # and give every node the lowest color of the map.
        norm_color = {feat: 0.0 for feat in importances}
    else:
        norm_color = {feat: float(value) / total for feat, value in importances.items()}

    cm = plt.get_cmap(cmap)
    g = BN2dot(bn, nodeColor=norm_color, cmapNode=cm)
    gnb.showGraph(g)
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
import pyagrum as gum
|
|
42
|
+
from pyagrum.explain._Explanation import Explanation
|
|
43
|
+
from typing import Callable, Dict
|
|
44
|
+
import numpy as np
|
|
45
|
+
|
|
46
|
+
import matplotlib.pyplot as plt
|
|
47
|
+
from matplotlib.patches import Polygon
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _POSTERIOR(y: int, func: str) -> str:
|
|
51
|
+
return f"logit($p(y={y} \\mid x)$)" if func == "_logit" else f"$p(y={y} \\mid x)$"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _JOIN(func: str) -> str:
|
|
55
|
+
return "log($p(x \\mid \\theta)$)" if func == "_log" else "$p(x \\mid \\theta)$"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _FMT(func: str) -> str:
|
|
59
|
+
return ".2e" if func == "_identity" else ".2f"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def waterfall(explanation: Explanation, y: int = 1, ax=None, real_values: Dict = None):
    """
    Plots a waterfall chart of the SHAP/SHALL values.

    Starting from the baseline (expected value), one arrow per feature is drawn,
    sorted by decreasing absolute contribution, ending at the model output.

    Parameters
    ----------
    explanation : Explanation
        The explanation object containing the SHAP/SHALL values.
    y : int, optional
        If the values type of the explanation is SHALL, then y is ignored.
        Else it is the class for which to plot the SHAP values.
    ax : matplotlib.Axes, optional
        The matplotlib Axes object to plot on (default is None, which creates a new figure).
    real_values : Dict, optional
        Dictionary used to display custom values for each feature.
        For example, useful when continuous values have been discretized but you still want
        to show the original continuous values from the database.
        The keys of the dictionary must match the keys in the Explanation object, and the
        values are the values you want to display on the plot.

    Raises
    ------
    TypeError
        If `explanation` is not an Explanation object or if `y` is not an integer.
    IndexError
        If `y` is an integer but out of bounds for the explanation keys.
    ValueError
        If the values type of the explanation is neither SHAP nor SHALL.
    """
    if not isinstance(explanation, Explanation):
        raise TypeError("`explanation` must be an Explanation object but got {}".format(type(explanation)))

    is_shap = explanation.values_type == "SHAP"
    if is_shap:
        if not isinstance(y, int):
            raise TypeError("`y`must be an integer but got {}".format(y))
        lowest, highest = min(explanation.keys()), max(explanation.keys())
        if y < lowest or y > highest:
            raise IndexError(f"Target index y={y} is out of bounds; expected {lowest} <= y < {highest + 1}.")
        values = explanation[y]
        baseline = explanation.baseline[y]
    elif explanation.values_type == "SHALL":
        values = explanation._values
        baseline = explanation.baseline
    else:
        raise ValueError(f"Wrong values type, expected SHAP/SHALL but got {explanation.values_type}")

    fmt = _FMT(explanation.func)
    # The label depends on what is being explained: SHAP values explain the posterior
    # of class y, SHALL values the joint likelihood. Both the baseline and the output
    # line use the same label, so they must be chosen from values_type (not func).
    output_label = _POSTERIOR(y, explanation.func) if is_shap else _JOIN(explanation.func)

    # Arrow-head width, scaled on the largest absolute contribution.
    arrow_width_base = 0.08 * np.max(np.abs(np.array(list(values.values()))))

    # Sort the features by decreasing absolute contribution.
    features = sorted(values.keys(), key=lambda feat: abs(values[feat]), reverse=True)
    # One row every 0.25 units, the most important feature at the top.
    y_positions = np.arange(len(values) * 0.25, 0, -0.25)
    y_bottom = y_positions[-1] - 0.25
    y_top = y_positions[0] + 0.25

    # Create the figure and axis if not provided
    if ax is None:
        _, ax = plt.subplots()

    # Baseline (expected value)
    ax.plot([baseline, baseline], [y_bottom, y_top], linestyle="--", color="gray")
    ax.text(
        baseline,
        y_positions[0] + 0.5,
        f"E({output_label}) = {baseline:{fmt}}",
        ha="center",
        va="bottom",
        color="gray",
    )

    # One arrow per feature, chained from the baseline to the model output.
    height = 0.2
    current_x = min_x = max_x = baseline
    for z, feature in zip(y_positions, features):
        delta = values[feature]
        x_start = current_x
        x_end = current_x + delta
        arrow_width = min(0.4 * abs(delta), arrow_width_base)
        # `direction` places the base of the arrow head behind the tip:
        # non-positive contributions point left, positive ones point right.
        if delta <= 0:
            facecolor, edgecolor, direction = gum.config["notebook", "tensor_color_0"], "#D98383", 1
        else:
            facecolor, edgecolor, direction = gum.config["notebook", "tensor_color_1"], "#82D882", -1

        polygon = Polygon(
            [
                (x_end + direction * arrow_width, z - height / 2),
                (x_start, z - height / 2),
                (x_start, z + height / 2),
                (x_end + direction * arrow_width, z + height / 2),
                (x_end, z),  # arrow tip
            ],
            closed=True,
            facecolor=facecolor,
            edgecolor=edgecolor,
            alpha=0.8,
            linewidth=2,
        )
        ax.add_patch(polygon)

        current_x = x_end
        min_x = min(min_x, current_x)
        max_x = max(max_x, current_x)

    # Model output
    ax.plot([current_x, current_x], [y_bottom, y_top], linestyle="--", color="Black")
    ax.text(
        current_x,
        y_positions[-1] - 0.5,
        f"{output_label} = {current_x:{fmt}}",
        ha="center",
        va="bottom",
        color="Black",
    )

    # Tick labels: "feature = displayed value [contribution]".
    y_tickslabels = []
    for feature in features:
        if real_values is not None:
            value = real_values[feature]
            if isinstance(value, float):
                value = round(value, 2)
        else:
            value = explanation.data[explanation.feature_names.index(feature)]
        y_tickslabels.append(f"{feature} = {value} [{values[feature]:{fmt}}]")

    ax.set_yticks(y_positions)
    ax.set_yticklabels(y_tickslabels)

    # Setting the style
    ax.grid(axis="x", linestyle=":", alpha=0.5)
    ax.grid(axis="y", alpha=0.5)
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_visible(False)
    ax.figure.set_facecolor("White")

    # Limits are set on `ax` itself: plt.xlim/plt.ylim act on the *current* axes,
    # which is not `ax` when the caller passes one of several subplots.
    ax.set_ylim(min(y_positions) - 1, max(y_positions) + 1)
    span = max_x - min_x
    if span > 0:
        # With a zero span, identical limits would be degenerate; let autoscaling handle it.
        ax.set_xlim(min_x - 0.05 * span, max_x + 0.05 * span)
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
import warnings
|
|
42
|
+
|
|
43
|
+
# ShapValues
|
|
44
|
+
from pyagrum.explain import CausalShapValues, ConditionalShapValues, Explanation, MarginalShapValues
|
|
45
|
+
|
|
46
|
+
# Calculations
|
|
47
|
+
import pyagrum as gum
|
|
48
|
+
|
|
49
|
+
# Plots
|
|
50
|
+
import matplotlib.pyplot as plt
|
|
51
|
+
from pyagrum.explain.notebook import bar, beeswarm, waterfall
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ShapValues:
    """
    Class to compute Shapley values for a target variable in a Bayesian network.

    Front-end over ConditionalShapValues, MarginalShapValues and CausalShapValues
    that can also plot the resulting explanation.
    """

    def __init__(self, bn, target, logit=True):
        """
        Parameters
        ----------
        bn : pyagrum.BayesNet
            The Bayesian Network.
        target : int | str
            The node id (or node name) of the target.
        logit : bool
            If True, applies the logit transformation to the probabilities.

        Raises
        ------
        TypeError
            If bn is not a gum.BayesNet instance or target is not an integer or string.
        ValueError
            If target is not a valid node id (or node name) in the Bayesian Network.

        Warns
        -----
        UserWarning
            If logit is not a boolean.
        """
        if not isinstance(bn, gum.BayesNet):
            raise TypeError("bn must be a gum.BayesNet instance, but got {}".format(type(bn)))
        if isinstance(target, str):
            if target not in bn.names():
                raise ValueError("Target node name '{}' not found in the Bayesian Network.".format(target))
            target = bn.idFromName(target)  # Convert node name to ID.
        elif isinstance(target, int):
            if target not in bn.nodes():
                raise ValueError("Target node ID {} not found in the Bayesian Network.".format(target))
        else:
            raise TypeError("Target must be a node ID (int) or a node name (str), but got {}".format(type(target)))
        if not isinstance(logit, bool):
            warnings.warn("logit should be a boolean, unexpected calculation may occur.", UserWarning)

        # Instance attributes.
        self.bn = bn
        self.target = target  # always a node id at this point
        self.logit = logit

    @staticmethod
    def _plot(explanation: Explanation, y: int, plot: bool, plot_importance: bool, percentage: bool, filename: str):
        """
        Draw the requested figures side by side, then show them or save them to `filename`.

        The local plot is a waterfall when ``explanation.data`` is 1-dimensional and a
        beeswarm otherwise; the importance plot is a bar chart. When neither `plot` nor
        `plot_importance` is set, nothing is drawn, shown or saved.
        """
        n_figures = sum([plot, plot_importance])
        if n_figures == 0:
            return

        # squeeze=False always returns a 2-D array, so axs[0] is indexable even for one subplot.
        _, axs = plt.subplots(1, n_figures, figsize=(n_figures * 6, 5), squeeze=False)
        axs = axs[0]
        plot_index = 0

        if plot:
            # Local explanation
            if explanation.data.ndim == 1:
                waterfall(explanation=explanation, y=y, ax=axs[plot_index])
            else:
                beeswarm(explanation=explanation, y=y, ax=axs[plot_index])
            plot_index += 1
        if plot_importance:
            bar(explanation=explanation, y=y, ax=axs[plot_index], percentage=percentage)

        plt.tight_layout()

        if filename is None:
            plt.show()
        else:
            plt.savefig(filename)
            plt.close()

    def _explain(self, explainer, df, y, plot, plot_importance, percentage, filename):
        """Run `explainer` on `df`, plot as requested, and return the importances for class `y`."""
        # NOTE(review): the True flag presumably marks `df` as holding labels (not positions) — confirm.
        explanation = explainer.compute((df, True))
        self._plot(explanation, y, plot, plot_importance, percentage, filename)
        return explanation.importances[y]

    def conditional(
        self,
        df,
        y: int = 1,
        plot: bool = False,
        plot_importance: bool = False,
        percentage: bool = False,
        filename: str = None,
    ):
        """
        Computes the conditional Shapley values for each variable.

        Parameters
        ----------
        df : pandas DataFrame
            The input data for which to compute the Shapley values.
        y : int, optional
            The target class for which to compute the Shapley values (default is 1).
        plot : bool, optional
            If True, plots the waterfall or beeswarm plot depending on the number of rows in df (default is False).
        plot_importance : bool, optional
            If True, plots the bar chart of feature importance (default is False).
        percentage : bool, optional
            If True, the importance plot is shown in percent (default is False).
        filename : str, optional
            If provided, saves the plots to the specified filename instead of displaying them.

        Returns
        -------
        Dict[str, float]
            A dictionary containing the importances of each variable in the input data.
        """
        explainer = ConditionalShapValues(self.bn, self.target, self.logit)
        return self._explain(explainer, df, y, plot, plot_importance, percentage, filename)

    def marginal(
        self,
        df,
        y: int = 1,
        sample_size: int = 200,
        plot: bool = False,
        plot_importance: bool = False,
        percentage: bool = False,
        filename: str = None,
    ):
        """
        Computes the marginal Shapley values for each variable.

        Parameters
        ----------
        df : pandas DataFrame
            The input data for which to compute the Shapley values.
        y : int, optional
            The target class for which to compute the Shapley values (default is 1).
        sample_size : int, optional
            The number of samples to use for the background data (default is 200).
        plot : bool, optional
            If True, plots the waterfall or beeswarm plot depending on the number of rows in df (default is False).
        plot_importance : bool, optional
            If True, plots the bar chart of feature importance (default is False).
        percentage : bool, optional
            If True, the importance plot is shown in percent (default is False).
        filename : str, optional
            If provided, saves the plots to the specified filename instead of displaying them.

        Returns
        -------
        Dict[str, float]
            A dictionary containing the importances of each variable in the input data.
        """
        explainer = MarginalShapValues(self.bn, self.target, None, sample_size, self.logit)
        return self._explain(explainer, df, y, plot, plot_importance, percentage, filename)

    def causal(
        self,
        df,
        y: int = 1,
        sample_size: int = 200,
        plot: bool = False,
        plot_importance: bool = False,
        percentage: bool = False,
        filename: str = None,
    ):
        """
        Computes the causal Shapley values for each variable.

        Parameters
        ----------
        df : pandas DataFrame
            The input data for which to compute the Shapley values.
        y : int, optional
            The target class for which to compute the Shapley values (default is 1).
        sample_size : int, optional
            The number of samples to use for the background data (default is 200).
        plot : bool, optional
            If True, plots the waterfall or beeswarm plot depending on the number of rows in df (default is False).
        plot_importance : bool, optional
            If True, plots the bar chart of feature importance (default is False).
        percentage : bool, optional
            If True, the importance plot is shown in percent (default is False).
        filename : str, optional
            If provided, saves the plots to the specified filename instead of displaying them.

        Returns
        -------
        Dict[str, float]
            A dictionary containing the importances of each variable in the input data.
        """
        explainer = CausalShapValues(self.bn, self.target, None, sample_size, self.logit)
        return self._explain(explainer, df, y, plot, plot_importance, percentage, filename)
|
pyagrum/lib/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
pyagrum.lib is a set of python tools for pyagrum.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
# Module metadata: author and copyright notice of pyagrum.lib.
__author__, __copyright__ = "Pierre-Henri Wuillemin", "(c) 2016-2024 PARIS"
|