pyAgrum-nightly 2.2.1.9.dev202510271761405498__cp310-abi3-win_amd64.whl → 2.3.0.9.dev202510291761586496__cp310-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyAgrum-nightly might be problematic.
- pyagrum/_pyagrum.pyd +0 -0
- pyagrum/common.py +1 -1
- pyagrum/config.py +1 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/explain.py +11 -490
- pyagrum/pyagrum.py +17 -10
- {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/METADATA +1 -1
- {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/RECORD +36 -12
- pyagrum/lib/shapley.py +0 -661
- {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/LICENSE.md +0 -0
- {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/LICENSES/LGPL-3.0-or-later.txt +0 -0
- {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/LICENSES/MIT.txt +0 -0
- {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/WHEEL +0 -0
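
The headline change in this diff is a refactoring of the explainability tooling: the monolithic pyagrum/lib/shapley.py (661 lines, shown deleted below) is superseded by a new pyagrum.explain package with per-method computation modules (e.g. _ShapConditionalValues, _ShapMarginalValues, _ShapCausalValues), caches, and notebook plotting helpers. For orientation, here is a minimal usage sketch of the entry point being removed, with the API taken from the deleted file below; the model and data file names are placeholders, and the exact public names exported by the new pyagrum/explain/__init__.py are not visible in this diff:

import pandas as pd
import pyagrum as gum
from pyagrum.lib.shapley import ShapValues  # module removed by this release

bn = gum.loadBN("model.bif")       # placeholder: any discrete Bayesian network
train = pd.read_csv("train.csv")   # placeholder: a database over the BN's variables
shap = ShapValues(bn, target="Y")  # "Y" stands in for the target node's name
res = shap.conditional(train)      # Dict[str, float]: mean |Shap value| per feature
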
pyagrum/lib/shapley.py
DELETED
@@ -1,661 +0,0 @@
############################################################################
# This file is part of the aGrUM/pyAgrum library. #
# #
# Copyright (c) 2005-2025 by #
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
# - Christophe GONZALES(_at_AMU) #
# #
# The aGrUM/pyAgrum library is free software; you can redistribute it #
# and/or modify it under the terms of either : #
# #
# - the GNU Lesser General Public License as published by #
# the Free Software Foundation, either version 3 of the License, #
# or (at your option) any later version, #
# - the MIT license (MIT), #
# - or both in dual license, as here. #
# #
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
# #
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
# OTHER DEALINGS IN THE SOFTWARE. #
# #
# See LICENCES for more details. #
# #
# SPDX-FileCopyrightText: Copyright 2005-2025 #
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
# - Christophe GONZALES(_at_AMU) #
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
# #
# Contact : info_at_agrum_dot_org #
# homepage : http://agrum.gitlab.io #
# gitlab : https://gitlab.com/agrumery/agrum #
# #
############################################################################

"""
tools for BN qualitative analysis and explainability
"""

import math
from typing import Dict
import itertools

import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt

from copy import deepcopy
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as mcolors

import pyagrum as gum

_cdict = {
  "red": ((0.0, 0.1, 0.3), (1.0, 0.6, 1.0)),
  "green": ((0.0, 0.0, 0.0), (1.0, 0.6, 0.8)),
  "blue": ((0.0, 0.0, 0.0), (1.0, 1, 0.8)),
}
_INFOcmap = mpl.colors.LinearSegmentedColormap("my_colormap", _cdict, 256)


class ShapValues:
  """
  The ShapValues class implements the calculation of Shap values in Bayesian networks.

  The main implementation is based on Conditional Shap values [3]_, but the Interventional calculation method proposed in [2]_ is also present. In addition, a new causal method based on [1]_ is implemented, which is well suited for Bayesian networks.

  .. [1] Heskes, T., Sijben, E., Bucur, I., & Claassen, T. (2020). Causal Shapley Values: Exploiting Causal Knowledge. 34th Conference on Neural Information Processing Systems. Vancouver, Canada.

  .. [2] Janzing, D., Minorics, L., & Blöbaum, P. (2019). Feature relevance quantification in explainable AI: A causality problem. arXiv: Machine Learning. Retrieved 6 24, 2021, from https://arxiv.org/abs/1910.13413

  .. [3] Lundberg, S. M., & Su-In, L. (2017). A Unified Approach to Interpreting Model Predictions. 31st Conference on Neural Information Processing Systems. Long Beach, CA, USA.
  """

  @staticmethod
  def _logit(p):
    return np.log(p / (1 - p))

  @staticmethod
  def _comb(n, k):
    return math.comb(n, k)

  @staticmethod
  def _fact(n):
    return math.factorial(n)

  def __init__(self, bn, target):
    self.bn = bn
    self.target = target
    self.feats_names = self._get_feats_name(bn, target)
    self.results = None

  ################################## VARIABLES ##################################

  @staticmethod
  def _get_feats_name(bn, target):
    list_feats_name = list(bn.names())
    list_feats_name.remove(target)
    return list_feats_name

  def _get_list_names_order(self):
    ### Return a list of the BN's node names, indexed by node id
    list_node_names = [None] * len(self.bn.names())
    for name in self.bn.names():
      list_node_names[self.bn.idFromName(name)] = name
    return list_node_names

  @staticmethod
  def _coal_encoding(convert, coal):
    ### Convert a list of nodes such as ['X', 'Z'] to an array of 0s and 1s:
    ### ['X', 'Z'] -> [1, 0, 1, 0]
    temp = np.zeros(len(convert), dtype=int)
    for var in coal:
      i = convert.index(var)
      temp[i] = 1
    return temp

  def _get_markov_blanket(self):
    feats_markov_blanket = []
    for i in gum.MarkovBlanket(self.bn, self.target).nodes():
      convert = self._get_list_names_order()
      feats_markov_blanket.append(convert[i])
    feats_markov_blanket.remove(self.target)
    return feats_markov_blanket

  ################################## Get All Combinations ##################################

  def _get_all_coal_compress(self):
    ### Return : all coalitions of the features
    return (
      list(itertools.compress(self.feats_names, mask))
      for mask in itertools.product(*[[0, 1]] * (len(self.feats_names)))
    )

  ################################## PREDICTION ##################################
  def _filtrage(self, df, conditions):
    ### Return : the part of the DataFrame selected by the conditions, given as a dict
    ### whose keys are feature names and whose values are the values those features should take.
    ### Example : {'X1': 0, 'X2': 1}
    if conditions == {}:
      return df

    first = next(iter(conditions))
    new_df = df[df[first] == conditions[first]]
    conditions.pop(first)
    return self._filtrage(new_df, conditions)

  def _init_Inference(self):
    ie = gum.LazyPropagation(self.bn)
    ie.addTarget(self.target)
    for name in self.feats_names:
      ie.addEvidence(name, 0)
    return ie

  ##### Prediction functions ####

  def _pred_markov_blanket(self, df, ie, markov_blanket):
    unique = df.groupby(markov_blanket).agg(freq=(self.target, "count")).reset_index()
    result = 0
    for i in range(len(unique)):
      for name in markov_blanket:
        ie.chgEvidence(name, str(unique[name].iloc[i]))
      ie.makeInference()
      predicted = ie.posterior(self.target).toarray()[1]
      result = result + self._logit(predicted) * unique["freq"].iloc[i] / len(df)
    return result

  def _pred_markov_blanket_logit(self, df, ie, markov_blanket):
    unique = df.groupby(markov_blanket).agg(freq=(self.target, "count")).reset_index()
    result = 0
    for i in range(len(unique)):
      for name in markov_blanket:
        ie.chgEvidence(name, str(unique[name].iloc[i]))
      ie.makeInference()
      predicted = ie.posterior(self.target).toarray()[1]
      result = result + predicted * unique["freq"].iloc[i] / len(df)
    return self._logit(result)

  def _evidenceImpact(self, condi, ie):
    ie.eraseAllEvidence()
    for key in condi.keys():
      ie.addEvidence(key, str(condi[key]))
    ie.makeInference()
    return self._logit(ie.posterior(self.target).toarray()[1])

  ############################## MARGINAL #########################################

  def _predict_marginal(self, df, ie):
    result = []
    for i in range(len(df)):
      for name in self.feats_names:
        ie.chgEvidence(name, str(df[name].iloc[i]))
      ie.makeInference()
      result.append(ie.posterior(self.target).toarray()[1])
    return np.array(result)

  ################################## COMPUTE SHAP ##################################

  def _compute_SHAP_i(self, S_U_i, S, v, size_S):
    size_all_features = len(self.bn.nodes()) - 1
    diff = v[f"{S_U_i}"] - v[f"{S}"]
    return diff / self._invcoeff_shap(size_S, size_all_features)

  def _invcoeff_shap(self, S_size, len_features):
    return (len_features - S_size) * self._comb(len_features, S_size)

  ################################## Get the two coalitions to subtract, S_U_i and S #################################

  @staticmethod
  def _gen_coalitions2(list_feats):
    ### !!! THE FEATURE i HAS TO BE THE LAST ONE IN THE LIST OF FEATURES !!!
    ### Return : all coalitions with the feature
    for mask in itertools.product(*[[0, 1]] * (len(list_feats) - 1)):
      S_U_i = itertools.compress(list_feats, mask + (1,))
      S = itertools.compress(list_feats, mask + (0,))
      yield S_U_i, S

  ################################## Function to Compute CONDITIONAL SHAP Value ##################################
  def _conditional(self, train: pd.DataFrame):
    """
    Compute the conditional Shap values for each variable.

    Parameters
    ----------
    train: pandas.DataFrame
      the database

    Returns
    -------
    a tuple (pandas.DataFrame of Shap values, pandas.DataFrame of coalition values)
    """

    ie = self._init_Inference()

    v = train.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()

    convert = self._get_list_names_order()

    for i in range(len(v)):
      v["Baseline"] = self._evidenceImpact({}, ie)
      for coal in self._get_all_coal_compress():
        S = list(coal)
        condi = {}
        for var in S:
          condi[var] = v.loc[i, var]
        col_arr_name = self._coal_encoding(convert, coal)
        v.loc[i, f"{col_arr_name}"] = self._evidenceImpact(condi, ie)
    df = pd.DataFrame()
    for feat in self.feats_names:
      list_i_last = self.feats_names.copy()
      index_i = list_i_last.index(feat)
      list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]

      somme = 0
      for coal1, coal2 in self._gen_coalitions2(list_i_last):
        S_U_i = self._coal_encoding(convert, list(coal1))
        S = self._coal_encoding(convert, list(coal2))
        size_S = sum(S)
        somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)
      df[feat] = somme
    self.results = df
    return df, v

  def conditional(self, train, plot=False, plot_importance=False, percentage=False, filename=None):
    """
    Compute the conditional Shap values for each variable.

    Parameters
    ----------
    train: pandas.DataFrame
      the database
    plot: bool
      if True, plot the beeswarm graph of the Shap values
    plot_importance: bool
      if True, plot the importance plot
    percentage: bool
      if True, the importance plot is shown in percent.
    filename: str
      if not None, save the plot in the file

    Returns
    -------
    a dictionary Dict[str, float]
    """
    results, v = self._conditional(train)
    res = {}
    for col in results.columns:
      res[col] = abs(results[col]).mean()

    self._plotResults(results, v, plot, plot_importance, percentage)

    if plot or plot_importance:
      if filename is not None:
        plt.savefig(filename)
      else:
        plt.show()
      plt.close()

    return res

  ################################## Function to Compute MARGINAL SHAP Value ##################################
  def _marginal(self, df, size_sample_df):
    ie = self._init_Inference()
    convert = self._get_list_names_order()
    test = df[:size_sample_df]
    v = df.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()
    df = pd.DataFrame()

    for i in range(len(v)):
      for coal in self._get_all_coal_compress():
        intervention = test.copy()
        for var in coal:
          intervention[var] = v.loc[i, var]
        col_arr_name = self._coal_encoding(convert, coal)
        v.loc[i, f"{col_arr_name}"] = np.mean(self._logit(self._predict_marginal(intervention, ie)))

    for feat in self.feats_names:
      list_i_last = self.feats_names.copy()
      index_i = list_i_last.index(feat)
      list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]

      somme = 0
      for coal1, coal2 in self._gen_coalitions2(list_i_last):
        S_U_i = self._coal_encoding(convert, list(coal1))
        S = self._coal_encoding(convert, list(coal2))
        size_S = sum(S)
        somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)

      df[feat] = somme

    self.results = df
    return df, v

  def marginal(self, train, sample_size=200, plot=False, plot_importance=False, percentage=False, filename=None):
    """
    Compute the marginal Shap values for each variable.

    Parameters
    ----------
    train: pandas.DataFrame
      the database
    sample_size: int
      The computation of marginal Shap values is very slow; this parameter allows computing them on only a fragment of the database.
    plot: bool
      if True, plot the beeswarm graph of the Shap values
    plot_importance: bool
      if True, plot the importance plot
    percentage: bool
      if True, the importance plot is shown in percent.
    filename: str
      if not None, save the plot in the file

    Returns
    -------
    a dictionary Dict[str, float]
    """
    results, v = self._marginal(train, sample_size)
    res = {}
    for col in results.columns:
      res[col] = abs(results[col]).mean()

    self._plotResults(results, v, plot, plot_importance, percentage)

    if plot or plot_importance:
      if filename is not None:
        plt.savefig(filename)
      else:
        plt.show()
      plt.close()

    return res

  ################################## MUTILATION ######################################

  @staticmethod
  def _mutilation_Inference(bn, feats_name, target):
    ie = gum.LazyPropagation(bn)
    ie.addTarget(target)
    for name in feats_name:
      ie.addEvidence(name, 0)
    return ie

  def _causal(self, train):
    v = train.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()
    ie = self._init_Inference()

    convert = self._get_list_names_order()
    df = pd.DataFrame()

    v["Baseline"] = self._evidenceImpact({}, ie)

    for coal in self._get_all_coal_compress():
      for i in range(len(v)):
        bn_temp = gum.BayesNet(self.bn)
        S = list(coal)
        condi = {}
        for var in S:
          condi[var] = v.loc[i, var]
          for parent in bn_temp.parents(var):
            bn_temp.eraseArc(parent, bn_temp.idFromName(var))
        ie = self._mutilation_Inference(bn_temp, self.feats_names, self.target)
        col_arr_name = self._coal_encoding(convert, coal)
        v.loc[i, f"{col_arr_name}"] = self._evidenceImpact(condi, ie)
    for feat in self.feats_names:
      list_i_last = self.feats_names.copy()
      index_i = list_i_last.index(feat)
      list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]

      somme = 0
      for coal1, coal2 in self._gen_coalitions2(list_i_last):
        S_U_i = self._coal_encoding(convert, list(coal1))
        S = self._coal_encoding(convert, list(coal2))
        size_S = sum(S)
        somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)
      df[feat] = somme

    self.results = df
    return df, v

  def causal(self, train, plot=False, plot_importance=False, percentage=False, filename=None):
    """
    Compute the causal Shap values for each variable.

    Parameters
    ----------
    train: pandas.DataFrame
      the database
    plot: bool
      if True, plot the beeswarm graph of the Shap values
    plot_importance: bool
      if True, plot the importance plot
    percentage: bool
      if True, the importance plot is shown in percent.
    filename: str
      if not None, save the plot in the file

    Returns
    -------
    a dictionary Dict[str, float]
    """
    results, v = self._causal(train)

    res = {}
    for col in results.columns:
      res[col] = abs(results[col]).mean()

    self._plotResults(results, v, plot, plot_importance, percentage)

    if plot or plot_importance:
      if filename is not None:
        plt.savefig(filename)
      else:
        plt.show()
      plt.close()

    return res

  ################################## PLOT SHAP Value ##################################

  def _plotResults(self, results, v, plot=True, plot_importance=False, percentage=False):
    ax1 = ax2 = None
    if plot and plot_importance:
      fig = plt.figure(figsize=(15, 0.5 * len(results.columns)))
      ax1 = fig.add_subplot(1, 2, 1)
      ax2 = fig.add_subplot(1, 2, 2)
    if plot:
      shap_dict = results.to_dict(orient="list")
      sorted_dict = dict(sorted(shap_dict.items(), key=lambda x: sum(abs(i) for i in x[1]) / len(x[1])))
      data = np.array([sorted_dict[key] for key in sorted_dict])
      features = sorted_dict.keys()
      v = v[features]
      colors = v.transpose().to_numpy()
      self._plot_beeswarm_(data, colors, 250, 1.5, features, ax=ax1)
    if plot_importance:
      self._plot_importance(results, percentage, ax=ax2)

  @staticmethod
  def _plot_beeswarm_(data, colors, N, K, features, cmap=None, ax=None):
    """
    Returns a beeswarm plot (or strip plot) from the given data.

    Parameters
    ----------
    data: list of numpy array.
      Each element of the list is a numpy array containing the Shapley values for the feature to be displayed for a category.

    colors: list of numpy array.
      Each element of the list is a numpy array containing the values of the data points to be displayed for a category.

    Returns
    -------
    matplotlib.figure.Figure
    """
    min_value = np.min(data, axis=(0, 1))
    max_value = np.max(data, axis=(0, 1))
    bin_size = (max_value - min_value) / N
    if bin_size == 0:
      bin_size = 1
    horiz_shift = K * bin_size

    if ax is None:
      _, ax = plt.subplots()
    if cmap is None:
      # Set Color Map
      ## Define the hex colors
      color1 = gum.config["notebook", "tensor_color_0"]
      color2 = gum.config["notebook", "tensor_color_1"]

      ## Create the custom colormap
      cmap = mcolors.LinearSegmentedColormap.from_list("", [color1, color2])

    for n, d in enumerate(data):
      pos = n + 1

      d_shifted = d + np.random.normal(0, horiz_shift, len(d))
      # Sorting values
      d_sort = np.sort(d_shifted)

      # Creation of bins
      d_x = bin_size
      if (np.max(d_sort) - np.min(d_sort)) % d_x == 0:
        nb_bins = (np.max(d_sort) - np.min(d_sort)) // d_x
      else:
        nb_bins = (np.max(d_sort) - np.min(d_sort)) // d_x + 1
      bins = [np.min(d_sort) + i * d_x for i in range(int(nb_bins) + 1)]

      # Group by Bins
      subarr = []
      for k in range(1, len(bins)):
        group = d_sort[np.logical_and(d_sort >= bins[k - 1], d_sort < bins[k])]
        subarr.append(group)

      # For each bin compute the d_y (vertical shift)
      d_y = 0.025
      subarr_jitter = deepcopy(subarr)
      for i in range(len(subarr)):
        L = subarr[i].size
        if L > 0:
          for j in range(L):
            shift = d_y * (L - 1) / 2 - d_y * j
            subarr_jitter[i][j] = shift

      jitter = np.concatenate(subarr_jitter)

      sc = ax.scatter(d_shifted, pos + jitter, s=10, c=colors[n], cmap=cmap, alpha=0.7)

    ## Create the colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)

    cbar = plt.colorbar(sc, cax=cax, aspect=80)
    cbar.set_label("Data Point Value")

    ## Add text above and below the colorbar
    cax.text(0.5, 1.025, "High", transform=cax.transAxes, ha="center", va="center", fontsize=10)
    cax.text(0.5, -0.025, "Low", transform=cax.transAxes, ha="center", va="center", fontsize=10)

    ## Add y-axis tick labels
    ax.set_yticks([i + 1 for i in range(len(features))])
    ax.set_yticklabels(features)

    ## Set axis labels and title
    ax.set_ylabel("Features")
    ax.set_xlabel("Impact on output")
    ax.set_title("Shapley value (impact on model output)")

    # Return the figure
    return ax.get_figure()

  @staticmethod
  def _plot_importance(results, percentage=False, ax=None):
    series = pd.DataFrame(abs(results).mean(), columns=["value"])
    series["feat"] = series.index

    if ax is None:
      _, ax = plt.subplots()

    if percentage:
      series["value"] = series["value"].div(series["value"].sum(axis=0)).multiply(100)
      series = series.sort_values("value", ascending=True)
      ax.barh(series.feat, series.value, align="center")
      ax.set_xlabel("Mean(|SHAP value|)")
      ax.set_title("Feature Importance in %")
    else:
      series = series.sort_values("value", ascending=True)
      ax.barh(series.feat, series.value, align="center")
      ax.set_xlabel("Mean(|SHAP value|)")
      ax.set_title("Feature Importance")

    return ax.get_figure()

  @staticmethod
  def _plot_scatter(results, ax=None):
    if ax is None:
      _, ax = plt.subplots()

    res = {}
    for col in results.columns:
      res[col] = results[col].to_numpy()
    names = list(res.keys())
    values = list(res.values())
    for xe, ye in zip(names, values):
      ax.scatter(ye, [xe] * len(ye))
    ax.set_title("Shapley value (impact on model output)")
    return ax.get_figure()

  @staticmethod
  def _plot_violin(results, ax=None):
    data = []
    pos = []
    label = []
    series = pd.DataFrame(abs(results).mean(), columns=["value"])
    series = series.sort_values("value", ascending=True)
    series["feat"] = series.index
    for i, col in enumerate(series.feat):
      data.append(results[col].to_numpy())
      pos.append(i)
      label.append(col)
    if ax is None:
      _, ax = plt.subplots()
    ax.violinplot(data, pos, vert=False)
    ax.set_yticks(pos)
    ax.set_yticklabels(label)
    ax.set_title("Shapley value (impact on model output)")
    return ax.get_figure()


def getShapValues(bn, shaps, cmap="plasma"):
  """
  Just a wrapper around BN2dot to easily show the Shap values

  Parameters
  ----------
  bn: pyagrum.BayesNet
    The Bayesian network
  shaps: dict[str, float]
    The (Shap) values associated with each variable
  cmap: Matplotlib.ColorMap
    The colormap used for colouring the nodes

  Returns
  -------
  a pydot.graph
  """
  from pyagrum.lib.bn2graph import BN2dot

  norm_color = {}
  raw = list(shaps.values())
  norm = [float(i) / sum(raw) for i in raw]
  for i, feat in enumerate(list(shaps.keys())):
    norm_color[feat] = norm[i]
  cm = plt.get_cmap(cmap)
  g = BN2dot(bn, nodeColor=norm_color, cmapNode=cm)
  return g
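
A note on the arithmetic in the deleted class: _compute_SHAP_i divides each marginal contribution v(S U {i}) - v(S) by _invcoeff_shap(|S|, n) = (n - |S|) * C(n, |S|), which is exactly the reciprocal of the classic Shapley weight |S|! (n - |S| - 1)! / n!. A self-contained sanity check of that identity (plain Python, not part of the package):

import math

def shapley_weight(s, n):
  # classic Shapley coefficient for a coalition of size s among n features
  return math.factorial(s) * math.factorial(n - s - 1) / math.factorial(n)

def invcoeff_shap(s, n):
  # the deleted helper's formula: (n - |S|) * C(n, |S|)
  return (n - s) * math.comb(n, s)

# the two expressions agree for every admissible coalition size
n = 6
assert all(abs(shapley_weight(s, n) - 1.0 / invcoeff_shap(s, n)) < 1e-12 for s in range(n))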