pyAgrum-nightly 2.2.1.9.dev202510271761405498-cp310-abi3-win_amd64.whl → 2.3.0.9.dev202510291761586496-cp310-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pyAgrum-nightly might be problematic.

Files changed (37)
  1. pyagrum/_pyagrum.pyd +0 -0
  2. pyagrum/common.py +1 -1
  3. pyagrum/config.py +1 -0
  4. pyagrum/explain/_ComputationCausal.py +75 -0
  5. pyagrum/explain/_ComputationConditional.py +48 -0
  6. pyagrum/explain/_ComputationMarginal.py +48 -0
  7. pyagrum/explain/_CustomShapleyCache.py +110 -0
  8. pyagrum/explain/_Explainer.py +176 -0
  9. pyagrum/explain/_Explanation.py +70 -0
  10. pyagrum/explain/_FIFOCache.py +54 -0
  11. pyagrum/explain/_ShallCausalValues.py +204 -0
  12. pyagrum/explain/_ShallConditionalValues.py +155 -0
  13. pyagrum/explain/_ShallMarginalValues.py +155 -0
  14. pyagrum/explain/_ShallValues.py +296 -0
  15. pyagrum/explain/_ShapCausalValues.py +208 -0
  16. pyagrum/explain/_ShapConditionalValues.py +126 -0
  17. pyagrum/explain/_ShapMarginalValues.py +191 -0
  18. pyagrum/explain/_ShapleyValues.py +298 -0
  19. pyagrum/explain/__init__.py +81 -0
  20. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  21. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  22. pyagrum/explain/_explInformationGraph.py +264 -0
  23. pyagrum/explain/notebook/__init__.py +54 -0
  24. pyagrum/explain/notebook/_bar.py +142 -0
  25. pyagrum/explain/notebook/_beeswarm.py +174 -0
  26. pyagrum/explain/notebook/_showShapValues.py +97 -0
  27. pyagrum/explain/notebook/_waterfall.py +220 -0
  28. pyagrum/explain/shapley.py +225 -0
  29. pyagrum/lib/explain.py +11 -490
  30. pyagrum/pyagrum.py +17 -10
  31. {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/METADATA +1 -1
  32. {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/RECORD +36 -12
  33. pyagrum/lib/shapley.py +0 -661
  34. {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/LICENSE.md +0 -0
  35. {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/LICENSES/LGPL-3.0-or-later.txt +0 -0
  36. {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/LICENSES/MIT.txt +0 -0
  37. {pyagrum_nightly-2.2.1.9.dev202510271761405498.dist-info → pyagrum_nightly-2.3.0.9.dev202510291761586496.dist-info}/WHEEL +0 -0
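The most significant change is the move of the explainability tooling out of pyagrum.lib and into the new pyagrum/explain package: pyagrum/lib/explain.py shrinks from roughly 500 lines to a thin shim, and pyagrum/lib/shapley.py is deleted outright, presumably superseded by the new pyagrum/explain/shapley.py. For reference, the deleted module exposed a ShapValues class; the sketch below reconstructs its usage from the docstrings in the diff that follows. The network file, CSV file, and target name are hypothetical, and the imported module no longer exists in this release:

    import pandas as pd
    import pyagrum as gum
    from pyagrum.lib.shapley import ShapValues  # module deleted in this release

    bn = gum.loadBN("model.bif")      # hypothetical binary-target network
    train = pd.read_csv("train.csv")  # hypothetical discretized database

    explainer = ShapValues(bn, target="Y")  # "Y" is a hypothetical target node
    # each method returns dict[str, float]: mean |Shap value| per feature
    conditional = explainer.conditional(train)
    marginal = explainer.marginal(train, sample_size=200)
    causal = explainer.causal(train)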
pyagrum/lib/shapley.py DELETED
@@ -1,661 +0,0 @@
- ############################################################################
- # This file is part of the aGrUM/pyAgrum library. #
- # #
- # Copyright (c) 2005-2025 by #
- # - Pierre-Henri WUILLEMIN(_at_LIP6) #
- # - Christophe GONZALES(_at_AMU) #
- # #
- # The aGrUM/pyAgrum library is free software; you can redistribute it #
- # and/or modify it under the terms of either : #
- # #
- # - the GNU Lesser General Public License as published by #
- # the Free Software Foundation, either version 3 of the License, #
- # or (at your option) any later version, #
- # - the MIT license (MIT), #
- # - or both in dual license, as here. #
- # #
- # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
- # #
- # This aGrUM/pyAgrum library is distributed in the hope that it will be #
- # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
- # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
- # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
- # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
- # OTHER DEALINGS IN THE SOFTWARE. #
- # #
- # See LICENCES for more details. #
- # #
- # SPDX-FileCopyrightText: Copyright 2005-2025 #
- # - Pierre-Henri WUILLEMIN(_at_LIP6) #
- # - Christophe GONZALES(_at_AMU) #
- # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
- # #
- # Contact : info_at_agrum_dot_org #
- # homepage : http://agrum.gitlab.io #
- # gitlab : https://gitlab.com/agrumery/agrum #
- # #
- ############################################################################
-
- """
- tools for BN qualitative analysis and explainability
- """
-
- import math
- from typing import Dict
- import itertools
-
- import numpy as np
- import pandas as pd
-
- import matplotlib as mpl
- import matplotlib.pyplot as plt
-
- from copy import deepcopy
- from mpl_toolkits.axes_grid1 import make_axes_locatable
- import matplotlib.colors as mcolors
-
- import pyagrum as gum
-
- _cdict = {
-   "red": ((0.0, 0.1, 0.3), (1.0, 0.6, 1.0)),
-   "green": ((0.0, 0.0, 0.0), (1.0, 0.6, 0.8)),
-   "blue": ((0.0, 0.0, 0.0), (1.0, 1, 0.8)),
- }
- _INFOcmap = mpl.colors.LinearSegmentedColormap("my_colormap", _cdict, 256)
-
-
- class ShapValues:
-   """
-   The ShapValue class implements the calculation of Shap values in Bayesian networks.
-
-   The main implementation is based on Conditional Shap values [3]_, but the Interventional calculation method proposed in [2]_ is also present. In addition, a new causal method, based on [1]_, is implemented which is well suited for Bayesian networks.
-
-   .. [1] Heskes, T., Sijben, E., Bucur, I., & Claassen, T. (2020). Causal Shapley Values: Exploiting Causal Knowledge. 34th Conference on Neural Information Processing Systems. Vancouver, Canada.
-
-   .. [2] Janzing, D., Minorics, L., & Blöbaum, P. (2019). Feature relevance quantification in explainable AI: A causality problem. arXiv: Machine Learning. Retrieved 6 24, 2021, from https://arxiv.org/abs/1910.13413
-
-   .. [3] Lundberg, S. M., & Su-In, L. (2017). A Unified Approach to Interpreting Model. 31st Conference on Neural Information Processing Systems. Long Beach, CA, USA.
-   """
-
-   @staticmethod
-   def _logit(p):
-     return np.log(p / (1 - p))
-
-   @staticmethod
-   def _comb(n, k):
-     return math.comb(n, k)
-
-   @staticmethod
-   def _fact(n):
-     return math.factorial(n)
-
-   def __init__(self, bn, target):
-     self.bn = bn
-     self.target = target
-     self.feats_names = self._get_feats_name(bn, target)
-     self.results = None
-
-   ################################## VARIABLES ##################################
-
-   @staticmethod
-   def _get_feats_name(bn, target):
-     list_feats_name = list(bn.names())
-     list_feats_name.remove(target)
-     return list_feats_name
-
-   def _get_list_names_order(self):
-     ### Return a list of BN's nodes names
-     list_node_names = [None] * len(self.bn.names())
-     for name in self.bn.names():
-       list_node_names[self.bn.idFromName(name)] = name
-     return list_node_names
-
-   @staticmethod
-   def _coal_encoding(convert, coal):
-     ### Convert a list of nodes : ['X', 'Z'] to an array of 0 and 1.
-     ### ['X', 'Z'] -> [1,0,1,0]
-     temp = np.zeros(len(convert), dtype=int)
-     for var in coal:
-       i = convert.index(var)
-       temp[i] = 1
-     return temp
-
-   def _get_markov_blanket(self):
-     feats_markov_blanket = []
-     for i in gum.MarkovBlanket(self.bn, self.target).nodes():
-       convert = self._get_list_names_order()
-       feats_markov_blanket.append(convert[i])
-     feats_markov_blanket.remove(self.target)
-     return feats_markov_blanket
-
-   ################################## Get All Combinations ##################################
-
-   def _get_all_coal_compress(self):
-     ### Return : all coalitions with the feature
-     return (
-       list(itertools.compress(self.feats_names, mask))
-       for mask in itertools.product(*[[0, 1]] * (len(self.feats_names)))
-     )
-
-   ################################## PREDICTION ##################################
-   def _filtrage(self, df, conditions):
-     ### Return : a selected part of DataFrame based on conditions, conditions must be in a dict :
-     ### The key is the name of the features and the value is the value that the features should take
-     ### Example : {'X1':0, 'X2':1}
-     if conditions == {}:
-       return df
-
-     first = next(iter(conditions))
-     new_df = df[df[first] == conditions[first]]
-     conditions.pop(first)
-     return self._filtrage(new_df, conditions)
-
-   def _init_Inference(self):
-     ie = gum.LazyPropagation(self.bn)
-     ie.addTarget(self.target)
-     for name in self.feats_names:
-       ie.addEvidence(name, 0)
-     return ie
-
-   ##### Prediction fonctions ####
-
-   def _pred_markov_blanket(self, df, ie, markov_blanket):
-     unique = df.groupby(markov_blanket).agg(freq=(self.target, "count")).reset_index()
-     result = 0
-     for i in range(len(unique)):
-       for name in markov_blanket:
-         ie.chgEvidence(name, str(unique[name].iloc[i]))
-       ie.makeInference()
-       predicted = ie.posterior(self.target).toarray()[1]
-       result = result + self._logit(predicted) * unique["freq"].iloc[i] / len(df)
-     return result
-
-   def _pred_markov_blanket_logit(self, df, ie, markov_blanket):
-     unique = df.groupby(markov_blanket).agg(freq=("Y", "count")).reset_index()
-     result = 0
-     for i in range(len(unique)):
-       for name in markov_blanket:
-         ie.chgEvidence(name, str(unique[name].iloc[i]))
-       ie.makeInference()
-       predicted = ie.posterior(self.target).toarray()[1]
-       result = result + predicted * unique["freq"].iloc[i] / len(df)
-     return self._logit(result)
-
-   def _evidenceImpact(self, condi, ie):
-     ie.eraseAllEvidence()
-     for key in condi.keys():
-       ie.addEvidence(key, str(condi[key]))
-     ie.makeInference()
-     return self._logit(ie.posterior(self.target).toarray()[1])
-
-   ############################## MARGINAL#########################################
-
-   def _predict_marginal(self, df, ie):
-     result = []
-     for i in range(len(df)):
-       for name in self.feats_names:
-         ie.chgEvidence(name, str(df[name].iloc[i]))
-       ie.makeInference()
-       result.append(ie.posterior(self.target).toarray()[1])
-     return np.array(result)
-
-   ################################## COMPUTE SHAP ##################################
-
-   def _compute_SHAP_i(self, S_U_i, S, v, size_S):
-     size_all_features = len(self.bn.nodes()) - 1
-     diff = v[f"{S_U_i}"] - v[f"{S}"]
-     return diff / self._invcoeff_shap(size_S, size_all_features)
-
-   def _invcoeff_shap(self, S_size, len_features):
-     return (len_features - S_size) * self._comb(len_features, S_size)
-
-   ################################## Get the two coalitions to substract S_U_i and S #################################
-
-   @staticmethod
-   def _gen_coalitions2(list_feats):
-     ### !!!! THE FEATURE i HAVE TO BE THE LAST IN THE LIST OF FEATURES !!!
-     ### Return : all coalitions with the feature
-     for mask in itertools.product(*[[0, 1]] * (len(list_feats) - 1)):
-       S_U_i = itertools.compress(list_feats, mask + (1,))
-       S = itertools.compress(list_feats, mask + (0,))
-       yield S_U_i, S
-
-   ################################## Function to Compute CONDITIONNAL SHAP Value ##################################
-   def _conditional(self, train: pd.DataFrame) -> Dict[str, float]:
-     """
-     Compute the conditional Shap Values for each variables.
-
-     Parameters
-     ----------
-     train :pandas.DataFrame
-       the database
-
-     Returns
-     -------
-     a dictionary Dict[str,float]
-     """
-
-     ie = self._init_Inference()
-
-     v = train.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()
-
-     convert = self._get_list_names_order()
-
-     for i in range(len(v)):
-       v["Baseline"] = self._evidenceImpact({}, ie)
-       for coal in self._get_all_coal_compress():
-         S = list(coal)
-         condi = {}
-         for var in S:
-           condi[var] = v.loc[i, var]
-         col_arr_name = self._coal_encoding(convert, coal)
-         v.loc[i, f"{col_arr_name}"] = self._evidenceImpact(condi, ie)
-     df = pd.DataFrame()
-     for feat in self.feats_names:
-       list_i_last = self.feats_names.copy()
-       index_i = list_i_last.index(feat)
-       list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]
-
-       somme = 0
-       for coal1, coal2 in self._gen_coalitions2(list_i_last):
-         S_U_i = self._coal_encoding(convert, list(coal1))
-         S = self._coal_encoding(convert, list(coal2))
-         size_S = sum(S)
-         somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)
-       df[feat] = somme
-     self.results = df
-     return df, v
-
-   def conditional(self, train, plot=False, plot_importance=False, percentage=False, filename=None):
-     """
-     Compute the conditional Shap Values for each variables.
-
-     Parameters
-     ----------
-     train :pandas.DataFrame
-       the database
-     plot: bool
-       if True, plot the violin graph of the shap values
-     plot_importance: bool
-       if True, plot the importance plot
-     percentage: bool
-       if True, the importance plot is shown in percent.
-     filename: str
-       if not None, save the plot in the file
-
-     Returns
-     -------
-     a dictionary Dict[str,float]
-     """
-     results, v = self._conditional(train)
-     res = {}
-     for col in results.columns:
-       res[col] = abs(results[col]).mean()
-
-     self._plotResults(results, v, plot, plot_importance, percentage)
-
-     if plot or plot_importance:
-       if filename is not None:
-         plt.savefig(filename)
-       else:
-         plt.show()
-       plt.close()
-
-     return res
-
-   ################################## Function to Compute MARGINAL SHAP Value ##################################
-   def _marginal(self, df, size_sample_df):
-     ie = self._init_Inference()
-     convert = self._get_list_names_order()
-     test = df[:size_sample_df]
-     v = df.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()
-     df = pd.DataFrame()
-
-     for i in range(len(v)):
-       for coal in self._get_all_coal_compress():
-         intervention = test.copy()
-         for var in coal:
-           intervention[var] = v.loc[i, var]
-         col_arr_name = self._coal_encoding(convert, coal)
-         v.loc[i, f"{col_arr_name}"] = np.mean(self._logit(self._predict_marginal(intervention, ie)))
-
-     for feat in self.feats_names:
-       list_i_last = self.feats_names.copy()
-       index_i = list_i_last.index(feat)
-       list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]
-
-       somme = 0
-       for coal1, coal2 in self._gen_coalitions2(list_i_last):
-         S_U_i = self._coal_encoding(convert, list(coal1))
-         S = self._coal_encoding(convert, list(coal2))
-         size_S = sum(S)
-         somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)
-
-       df[feat] = somme
-
-     self.results = df
-     return df, v
-
-   def marginal(self, train, sample_size=200, plot=False, plot_importance=False, percentage=False, filename=None):
-     """
-     Compute the marginal Shap Values for each variables.
-
-     Parameters
-     ----------
-     train :pandas.DataFrame
-       the database
-     sample_size : int
-       The computation of marginal ShapValue is very slow. The parameter allow to compute only on a fragment of the database.
-     plot: bool
-       if True, plot the violin graph of the shap values
-     plot_importance: bool
-       if True, plot the importance plot
-     percentage: bool
-       if True, the importance plot is shown in percent.
-     filename: str
-       if not None, save the plot in the file
-
-     Returns
-     -------
-     a dictionary Dict[str,float]
-     """
-     results, v = self._marginal(train, sample_size)
-     res = {}
-     for col in results.columns:
-       res[col] = abs(results[col]).mean()
-
-     self._plotResults(results, v, plot, plot_importance, percentage)
-
-     if plot or plot_importance:
-       if filename is not None:
-         plt.savefig(filename)
-       else:
-         plt.show()
-       plt.close()
-
-     return res
-
-   ################################## MUTILATION ######################################
-
-   @staticmethod
-   def _mutilation_Inference(bn, feats_name, target):
-     ie = gum.LazyPropagation(bn)
-     ie.addTarget(target)
-     for name in feats_name:
-       ie.addEvidence(name, 0)
-     return ie
-
-   def _causal(self, train):
-     v = train.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()
-     ie = self._init_Inference()
-
-     convert = self._get_list_names_order()
-     df = pd.DataFrame()
-
-     v["Baseline"] = self._evidenceImpact({}, ie)
-
-     for coal in self._get_all_coal_compress():
-       for i in range(len(v)):
-         bn_temp = gum.BayesNet(self.bn)
-         S = list(coal)
-         condi = {}
-         for var in S:
-           condi[var] = v.loc[i, var]
-           for parent in bn_temp.parents(var):
-             bn_temp.eraseArc(parent, bn_temp.idFromName(var))
-         ie = self._mutilation_Inference(bn_temp, self.feats_names, self.target)
-         col_arr_name = self._coal_encoding(convert, coal)
-         v.loc[i, f"{col_arr_name}"] = self._evidenceImpact(condi, ie)
-     for feat in self.feats_names:
-       list_i_last = self.feats_names.copy()
-       index_i = list_i_last.index(feat)
-       list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]
-
-       somme = 0
-       for coal1, coal2 in self._gen_coalitions2(list_i_last):
-         S_U_i = self._coal_encoding(convert, list(coal1))
-         S = self._coal_encoding(convert, list(coal2))
-         size_S = sum(S)
-         somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)
-       df[feat] = somme
-
-     self.results = df
-     return df, v
-
-   def causal(self, train, plot=False, plot_importance=False, percentage=False, filename=None):
-     """
-     Compute the causal Shap Values for each variables.
-
-     Parameters
-     ----------
-     train :pandas.DataFrame
-       the database
-     plot: bool
-       if True, plot the violin graph of the shap values
-     plot_importance: bool
-       if True, plot the importance plot
-     percentage: bool
-       if True, the importance plot is shown in percent.
-     filename: str
-       if not None, save the plot in the file
-
-     Returns
-     -------
-     a dictionary Dict[str,float]
-     """
-     results, v = self._causal(train)
-
-     res = {}
-     for col in results.columns:
-       res[col] = abs(results[col]).mean()
-
-     self._plotResults(results, v, plot, plot_importance, percentage)
-
-     if plot or plot_importance:
-       if filename is not None:
-         plt.savefig(filename)
-       else:
-         plt.show()
-       plt.close()
-
-     return res
-
-   ################################## PLOT SHAP Value ##################################
-
-   def _plotResults(self, results, v, plot=True, plot_importance=False, percentage=False):
-     ax1 = ax2 = None
-     if plot and plot_importance:
-       fig = plt.figure(figsize=(15, 0.5 * len(results.columns)))
-       ax1 = fig.add_subplot(1, 2, 1)
-       ax2 = fig.add_subplot(1, 2, 2)
-     if plot:
-       shap_dict = results.to_dict(orient="list")
-       sorted_dict = dict(sorted(shap_dict.items(), key=lambda x: sum(abs(i) for i in x[1]) / len(x[1])))
-       data = np.array([sorted_dict[key] for key in sorted_dict])
-       features = sorted_dict.keys()
-       v = v[features]
-       colors = v.transpose().to_numpy()
-       self._plot_beeswarm_(data, colors, 250, 1.5, features, ax=ax1)
-     if plot_importance:
-       self._plot_importance(results, percentage, ax=ax2)
-
-   @staticmethod
-   def _plot_beeswarm_(data, colors, N, K, features, cmap=None, ax=None):
-     """
-     returns a beeswarm plot (or stripplot) from a given data.
-
-     Parameters
-     ----------
-     data: list of numpy array.
-       Each elements of the list is a numpy array containing the shapley values for the feature to be displayed for a category.
-
-     colors: list of numpy array.
-       Each elements of the list is a numpy array containing the values of the data point to be displayed for a category.
-
-     Returns
-     -------
-     matplotlib.pyplot.scatter
-     """
-     min_value = np.min(data, axis=(0, 1))
-     max_value = np.max(data, axis=(0, 1))
-     bin_size = (max_value - min_value) / N
-     if bin_size == 0:
-       bin_size = 1
-     horiz_shift = K * bin_size
-
-     if ax is None:
-       _, ax = plt.subplots()
-     if cmap is None:
-       # Set Color Map
-       ## Define the hex colors
-       color1 = gum.config["notebook", "tensor_color_0"]
-       color2 = gum.config["notebook", "tensor_color_1"]
-
-       ## Create the custom colormap
-       cmap = mcolors.LinearSegmentedColormap.from_list("", [color1, color2])
-
-     for n, d in enumerate(data):
-       pos = n + 1
-
-       d_shifted = d + np.random.normal(0, horiz_shift, len(d))
-       # Sorting values
-       d_sort = np.sort(d_shifted)
-
-       # Creation of bins
-       d_x = bin_size
-       if (np.max(d_sort) - np.min(d_sort)) % d_x == 0:
-         nb_bins = (np.max(d_sort) - np.min(d_sort)) // d_x
-       else:
-         nb_bins = (np.max(d_sort) - np.min(d_sort)) // d_x + 1
-       bins = [np.min(d_sort) + i * d_x for i in range(int(nb_bins) + 1)]
-
-       # Group by Bins
-       subarr = []
-       for k in range(1, len(bins)):
-         group = d_sort[np.logical_and(d_sort >= bins[k - 1], d_sort < bins[k])]
-         subarr.append(group)
-
-       # For each bin compute the d_y (vertical shift)
-       d_y = 0.025
-       subarr_jitter = deepcopy(subarr)
-       for i in range(len(subarr)):
-         L = subarr[i].size
-         if L > 0:
-           for j in range(L):
-             shift = d_y * (L - 1) / 2 - d_y * j
-             subarr_jitter[i][j] = shift
-
-       jitter = np.concatenate(subarr_jitter)
-
-       sc = ax.scatter(d_shifted, pos + jitter, s=10, c=colors[n], cmap=cmap, alpha=0.7)
-
-     ## Create the colorbar
-     divider = make_axes_locatable(ax)
-     cax = divider.append_axes("right", size="5%", pad=0.1)
-
-     cbar = plt.colorbar(sc, cax=cax, aspect=80)
-     cbar.set_label("Data Point Value")
-
-     ## Add text above and below the colorbar
-     cax.text(0.5, 1.025, "High", transform=cax.transAxes, ha="center", va="center", fontsize=10)
-     cax.text(0.5, -0.025, "Low", transform=cax.transAxes, ha="center", va="center", fontsize=10)
-
-     ## Add x-axis tick labels
-     ax.set_yticks([i + 1 for i in range(len(features))])
-     ax.set_yticklabels(features)
-
-     ## Set axis labels and title
-     ax.set_ylabel("Features")
-     ax.set_xlabel("Impact on output")
-     ax.set_title("Shapley value (impact on model output)")
-
-     # Show plot
-     return ax.get_figure()
-
-   @staticmethod
-   def _plot_importance(results, percentage=False, ax=None):
-     series = pd.DataFrame(abs(results).mean(), columns=["value"])
-     series["feat"] = series.index
-
-     if ax is None:
-       _, ax = plt.subplots()
-
-     if percentage:
-       series["value"] = series["value"].div(series["value"].sum(axis=0)).multiply(100)
-       series = series.sort_values("value", ascending=True)
-       ax.barh(series.feat, series.value, align="center")
-       ax.set_xlabel("Mean(|SHAP value|)")
-       ax.set_title("Feature Importance in %")
-     else:
-       series = series.sort_values("value", ascending=True)
-       ax.barh(series.feat, series.value, align="center")
-       ax.set_xlabel("Mean(|SHAP value|)")
-       ax.set_title("Feature Importance")
-
-     return ax.get_figure()
-
-   @staticmethod
-   def _plot_scatter(results, ax=None):
-     if ax is None:
-       _, ax = plt.subplots()
-
-     res = {}
-     for col in results.columns:
-       res[col] = results[col].to_numpy()
-     names = list(res.keys())
-     values = list(res.values())
-     for xe, ye in zip(names, values):
-       ax.scatter(ye, [xe] * len(ye))
-     ax.set_title("Shapley value (impact on model output)")
-     return ax.get_figure()
-
-   @staticmethod
-   def _plot_violin(results, ax=None):
-     data = []
-     pos = []
-     label = []
-     series = pd.DataFrame(abs(results).mean(), columns=["value"])
-     series = series.sort_values("value", ascending=True)
-     series["feat"] = series.index
-     for i, col in enumerate(series.feat):
-       data.append(results[col].to_numpy())
-       pos.append(i)
-       label.append(col)
-     if ax is None:
-       _, ax = plt.subplots()
-     ax.violinplot(data, pos, vert=False)
-     ax.set_yticks(pos)
-     ax.set_yticklabels(label)
-     ax.set_title("Shapley value (impact on model output)")
-     return ax.get_figure()
-
-
- def getShapValues(bn, shaps, cmap="plasma"):
-   """
-   Just a wrapper around BN2dot to easily show the Shap values
-
-   Parameters
-   ----------
-   bn : pyagrum.BayesNet
-     The Bayesian network
-   shaps: dict[str,float]
-     The (Shap) values associates to each variable
-   cmap: Matplotlib.ColorMap
-     The colormap used for colouring the nodes
-
-   Returns
-   -------
-   a pydot.graph
-   """
-   from pyagrum.lib.bn2graph import BN2dot
-
-   norm_color = {}
-   raw = list(shaps.values())
-   norm = [float(i) / sum(raw) for i in raw]
-   for i, feat in enumerate(list(shaps.keys())):
-     norm_color[feat] = norm[i]
-   cm = plt.get_cmap(cmap)
-   g = BN2dot(bn, nodeColor=norm_color, cmapNode=cm)
-   return g
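A note on the arithmetic in _compute_SHAP_i and _invcoeff_shap above: each marginal contribution v(S ∪ {i}) − v(S) is divided by (n − |S|) · C(n, |S|), where n is the number of features. That denominator is exactly the inverse of the classical Shapley weight |S|! (n − |S| − 1)! / n!, since 1 / ((n − s) · C(n, s)) = s! (n − s)! / ((n − s) · n!) = s! (n − s − 1)! / n!. A minimal, self-contained check of the identity (independent of pyAgrum):

    import math

    def shapley_weight(s: int, n: int) -> float:
        # classical Shapley coefficient: |S|! (n - |S| - 1)! / n!
        return math.factorial(s) * math.factorial(n - s - 1) / math.factorial(n)

    def invcoeff(s: int, n: int) -> int:
        # denominator used by ShapValues._invcoeff_shap: (n - |S|) * C(n, |S|)
        return (n - s) * math.comb(n, s)

    # the two agree for every coalition size s = 0 .. n-1
    for n in range(1, 10):
        for s in range(n):
            assert abs(shapley_weight(s, n) - 1 / invcoeff(s, n)) < 1e-12

Because the exact computation sums this weighted difference over all 2^(n−1) coalitions per feature, the cost grows exponentially with the number of nodes, which is why the deleted marginal() exposed sample_size to cap the per-coalition work.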