pyAgrum-nightly 2.1.1.9.dev202506061747485979__cp310-abi3-manylinux2014_aarch64.whl → 2.3.1.9.dev202601031765915415__cp310-abi3-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. pyagrum/__init__.py +6 -2
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +6 -2
  4. pyagrum/bnmixture/BNMLearning.py +12 -2
  5. pyagrum/bnmixture/BNMixture.py +6 -2
  6. pyagrum/bnmixture/__init__.py +6 -2
  7. pyagrum/bnmixture/notebook.py +6 -2
  8. pyagrum/causal/_CausalFormula.py +6 -2
  9. pyagrum/causal/_CausalModel.py +6 -2
  10. pyagrum/causal/__init__.py +6 -2
  11. pyagrum/causal/_causalImpact.py +6 -2
  12. pyagrum/causal/_dSeparation.py +6 -2
  13. pyagrum/causal/_doAST.py +6 -2
  14. pyagrum/causal/_doCalculus.py +6 -2
  15. pyagrum/causal/_doorCriteria.py +6 -2
  16. pyagrum/causal/_exceptions.py +6 -2
  17. pyagrum/causal/_types.py +6 -2
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +6 -2
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +6 -2
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +6 -2
  21. pyagrum/causal/causalEffectEstimation/__init__.py +6 -2
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +6 -2
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +6 -2
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +6 -2
  25. pyagrum/causal/causalEffectEstimation/_learners.py +6 -2
  26. pyagrum/causal/causalEffectEstimation/_utils.py +6 -2
  27. pyagrum/causal/notebook.py +8 -3
  28. pyagrum/clg/CLG.py +6 -2
  29. pyagrum/clg/GaussianVariable.py +6 -2
  30. pyagrum/clg/SEM.py +6 -2
  31. pyagrum/clg/__init__.py +6 -2
  32. pyagrum/clg/canonicalForm.py +6 -2
  33. pyagrum/clg/constants.py +6 -2
  34. pyagrum/clg/forwardSampling.py +6 -2
  35. pyagrum/clg/learning.py +6 -2
  36. pyagrum/clg/notebook.py +6 -2
  37. pyagrum/clg/variableElimination.py +6 -2
  38. pyagrum/common.py +7 -3
  39. pyagrum/config.py +7 -2
  40. pyagrum/ctbn/CIM.py +6 -2
  41. pyagrum/ctbn/CTBN.py +6 -2
  42. pyagrum/ctbn/CTBNGenerator.py +6 -2
  43. pyagrum/ctbn/CTBNInference.py +6 -2
  44. pyagrum/ctbn/CTBNLearner.py +6 -2
  45. pyagrum/ctbn/SamplesStats.py +6 -2
  46. pyagrum/ctbn/StatsIndepTest.py +6 -2
  47. pyagrum/ctbn/__init__.py +6 -2
  48. pyagrum/ctbn/constants.py +6 -2
  49. pyagrum/ctbn/notebook.py +6 -2
  50. pyagrum/deprecated.py +6 -2
  51. pyagrum/explain/_ComputationCausal.py +75 -0
  52. pyagrum/explain/_ComputationConditional.py +48 -0
  53. pyagrum/explain/_ComputationMarginal.py +48 -0
  54. pyagrum/explain/_CustomShapleyCache.py +110 -0
  55. pyagrum/explain/_Explainer.py +176 -0
  56. pyagrum/explain/_Explanation.py +70 -0
  57. pyagrum/explain/_FIFOCache.py +54 -0
  58. pyagrum/explain/_ShallCausalValues.py +204 -0
  59. pyagrum/explain/_ShallConditionalValues.py +155 -0
  60. pyagrum/explain/_ShallMarginalValues.py +155 -0
  61. pyagrum/explain/_ShallValues.py +296 -0
  62. pyagrum/explain/_ShapCausalValues.py +208 -0
  63. pyagrum/explain/_ShapConditionalValues.py +126 -0
  64. pyagrum/explain/_ShapMarginalValues.py +191 -0
  65. pyagrum/explain/_ShapleyValues.py +298 -0
  66. pyagrum/explain/__init__.py +81 -0
  67. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  68. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  69. pyagrum/explain/_explInformationGraph.py +264 -0
  70. pyagrum/explain/notebook/__init__.py +54 -0
  71. pyagrum/explain/notebook/_bar.py +142 -0
  72. pyagrum/explain/notebook/_beeswarm.py +174 -0
  73. pyagrum/explain/notebook/_showShapValues.py +97 -0
  74. pyagrum/explain/notebook/_waterfall.py +220 -0
  75. pyagrum/explain/shapley.py +225 -0
  76. pyagrum/lib/__init__.py +6 -2
  77. pyagrum/lib/_colors.py +6 -2
  78. pyagrum/lib/bn2graph.py +6 -2
  79. pyagrum/lib/bn2roc.py +6 -2
  80. pyagrum/lib/bn2scores.py +6 -2
  81. pyagrum/lib/bn_vs_bn.py +6 -2
  82. pyagrum/lib/cn2graph.py +6 -2
  83. pyagrum/lib/discreteTypeProcessor.py +99 -81
  84. pyagrum/lib/discretizer.py +6 -2
  85. pyagrum/lib/dynamicBN.py +6 -2
  86. pyagrum/lib/explain.py +17 -492
  87. pyagrum/lib/export.py +6 -2
  88. pyagrum/lib/id2graph.py +6 -2
  89. pyagrum/lib/image.py +6 -2
  90. pyagrum/lib/ipython.py +6 -2
  91. pyagrum/lib/mrf2graph.py +6 -2
  92. pyagrum/lib/notebook.py +6 -2
  93. pyagrum/lib/proba_histogram.py +6 -2
  94. pyagrum/lib/utils.py +6 -2
  95. pyagrum/pyagrum.py +976 -126
  96. pyagrum/skbn/_MBCalcul.py +6 -2
  97. pyagrum/skbn/__init__.py +6 -2
  98. pyagrum/skbn/_learningMethods.py +6 -2
  99. pyagrum/skbn/_utils.py +6 -2
  100. pyagrum/skbn/bnclassifier.py +6 -2
  101. pyagrum_nightly-2.1.1.9.dev202506061747485979.dist-info/LICENSE → pyagrum_nightly-2.3.1.9.dev202601031765915415.dist-info/LICENSE.md +3 -1
  102. pyagrum_nightly-2.3.1.9.dev202601031765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  103. pyagrum_nightly-2.3.1.9.dev202601031765915415.dist-info/LICENSES/MIT.txt +18 -0
  104. {pyagrum_nightly-2.1.1.9.dev202506061747485979.dist-info → pyagrum_nightly-2.3.1.9.dev202601031765915415.dist-info}/METADATA +3 -1
  105. pyagrum_nightly-2.3.1.9.dev202601031765915415.dist-info/RECORD +107 -0
  106. {pyagrum_nightly-2.1.1.9.dev202506061747485979.dist-info → pyagrum_nightly-2.3.1.9.dev202601031765915415.dist-info}/WHEEL +1 -1
  107. pyagrum/lib/shapley.py +0 -657
  108. pyagrum_nightly-2.1.1.9.dev202506061747485979.dist-info/LICENSE.LGPL +0 -165
  109. pyagrum_nightly-2.1.1.9.dev202506061747485979.dist-info/LICENSE.MIT +0 -17
  110. pyagrum_nightly-2.1.1.9.dev202506061747485979.dist-info/RECORD +0 -83
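Note on the headline change: pyagrum/lib/shapley.py (657 lines) is removed, and a new pyagrum/explain subpackage appears (shapley.py plus the _Shap*Values.py and notebook modules listed above), which seems to be where this functionality now lives. For reference when reading the removed module below, the ShapValues class computes the classical Shapley value of each feature i for a coalition game v over the feature set F,

  \phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F| - |S| - 1)!}{|F|!} \left( v(S \cup \{i\}) - v(S) \right)

where v(S) is the logit of the posterior probability of the target given the coalition S (conditioned on, intervened on, or evaluated in a mutilated network, depending on the method). This is exactly the weight implemented by _invcoeff_shap below, since (|F| - |S|) * C(|F|, |S|) = |F|! / (|S|! (|F| - |S| - 1)!).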
pyagrum/lib/shapley.py DELETED
@@ -1,657 +0,0 @@
- ############################################################################
- # This file is part of the aGrUM/pyAgrum library.
- #
- # Copyright (c) 2005-2025 by
- #   - Pierre-Henri WUILLEMIN(_at_LIP6)
- #   - Christophe GONZALES(_at_AMU)
- #
- # The aGrUM/pyAgrum library is free software; you can redistribute it
- # and/or modify it under the terms of either :
- #
- #  - the GNU Lesser General Public License as published by
- #    the Free Software Foundation, either version 3 of the License,
- #    or (at your option) any later version,
- #  - the MIT license (MIT),
- #  - or both in dual license, as here.
- #
- # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html)
- #
- # This aGrUM/pyAgrum library is distributed in the hope that it will be
- # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
- # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS
- # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
- # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
- # OTHER DEALINGS IN THE SOFTWARE.
- #
- # See the GNU Lesser General Public License (LICENSE.LGPL) and the MIT
- # licence (LICENSE.MIT) for more details.
- #
- # Contact : info_at_agrum_dot_org
- # homepage : http://agrum.gitlab.io
- # gitlab : https://gitlab.com/agrumery/agrum
- #
- ############################################################################
-
- """
- tools for BN qualitative analysis and explainability
- """
-
- import math
- from typing import Dict
- import itertools
-
- import numpy as np
- import pandas as pd
-
- import matplotlib as mpl
- import matplotlib.pyplot as plt
-
- from copy import deepcopy
- from mpl_toolkits.axes_grid1 import make_axes_locatable
- import matplotlib.colors as mcolors
-
- import pyagrum as gum
-
- _cdict = {
-   "red": ((0.0, 0.1, 0.3), (1.0, 0.6, 1.0)),
-   "green": ((0.0, 0.0, 0.0), (1.0, 0.6, 0.8)),
-   "blue": ((0.0, 0.0, 0.0), (1.0, 1, 0.8)),
- }
- _INFOcmap = mpl.colors.LinearSegmentedColormap("my_colormap", _cdict, 256)
-
-
- class ShapValues:
-   """
-   The ShapValues class implements the calculation of Shap values in Bayesian networks.
-
-   The main implementation is based on Conditional Shap values [3]_, but the Interventional calculation method proposed in [2]_ is also present. In addition, a new causal method, based on [1]_, is implemented, which is well suited to Bayesian networks.
-
-   .. [1] Heskes, T., Sijben, E., Bucur, I., & Claassen, T. (2020). Causal Shapley Values: Exploiting Causal Knowledge. 34th Conference on Neural Information Processing Systems. Vancouver, Canada.
-
-   .. [2] Janzing, D., Minorics, L., & Blöbaum, P. (2019). Feature relevance quantification in explainable AI: A causality problem. arXiv: Machine Learning. Retrieved June 24, 2021, from https://arxiv.org/abs/1910.13413
-
-   .. [3] Lundberg, S. M., & Lee, S.-I. (2017). A Unified Approach to Interpreting Model Predictions. 31st Conference on Neural Information Processing Systems. Long Beach, CA, USA.
-   """
-
-   @staticmethod
-   def _logit(p):
-     return np.log(p / (1 - p))
-
-   @staticmethod
-   def _comb(n, k):
-     return math.comb(n, k)
-
-   @staticmethod
-   def _fact(n):
-     return math.factorial(n)
-
-   def __init__(self, bn, target):
-     self.bn = bn
-     self.target = target
-     self.feats_names = self._get_feats_name(bn, target)
-     self.results = None
-
-   ################################## VARIABLES ##################################
-
-   @staticmethod
-   def _get_feats_name(bn, target):
-     list_feats_name = list(bn.names())
-     list_feats_name.remove(target)
-     return list_feats_name
-
-   def _get_list_names_order(self):
-     ### Return the BN's node names, ordered by node id
-     list_node_names = [None] * len(self.bn.names())
-     for name in self.bn.names():
-       list_node_names[self.bn.idFromName(name)] = name
-     return list_node_names
-
-   @staticmethod
-   def _coal_encoding(convert, coal):
-     ### Convert a list of node names to a 0/1 array, e.g.
-     ### ['X', 'Z'] -> [1, 0, 1, 0]
-     temp = np.zeros(len(convert), dtype=int)
-     for var in coal:
-       i = convert.index(var)
-       temp[i] = 1
-     return temp
-
-   def _get_markov_blanket(self):
-     feats_markov_blanket = []
-     convert = self._get_list_names_order()
-     for i in gum.MarkovBlanket(self.bn, self.target).nodes():
-       feats_markov_blanket.append(convert[i])
-     feats_markov_blanket.remove(self.target)
-     return feats_markov_blanket
-
-   ################################## Get All Combinations ##################################
-
-   def _get_all_coal_compress(self):
-     ### Return: a generator over all coalitions (subsets) of the features
-     return (
-       list(itertools.compress(self.feats_names, mask))
-       for mask in itertools.product(*[[0, 1]] * len(self.feats_names))
-     )
-
-   ################################## PREDICTION ##################################
-   def _filtrage(self, df, conditions):
-     ### Return: the part of the DataFrame selected by `conditions`, a dict
-     ### mapping each feature name to the value that the feature should take.
-     ### Example: {'X1': 0, 'X2': 1}. Note: `conditions` is emptied in place.
-     if conditions == {}:
-       return df
-
-     first = next(iter(conditions))
-     new_df = df[df[first] == conditions[first]]
-     conditions.pop(first)
-     return self._filtrage(new_df, conditions)
-
-   def _init_Inference(self):
-     ie = gum.LazyPropagation(self.bn)
-     ie.addTarget(self.target)
-     for name in self.feats_names:
-       ie.addEvidence(name, 0)
-     return ie
-
-   #### Prediction functions ####
-
-   def _pred_markov_blanket(self, df, ie, markov_blanket):
-     unique = df.groupby(markov_blanket).agg(freq=(self.target, "count")).reset_index()
-     result = 0
-     for i in range(len(unique)):
-       for name in markov_blanket:
-         ie.chgEvidence(name, str(unique[name].iloc[i]))
-       ie.makeInference()
-       predicted = ie.posterior(self.target).toarray()[1]
-       result = result + self._logit(predicted) * unique["freq"].iloc[i] / len(df)
-     return result
-
-   def _pred_markov_blanket_logit(self, df, ie, markov_blanket):
-     unique = df.groupby(markov_blanket).agg(freq=(self.target, "count")).reset_index()
-     result = 0
-     for i in range(len(unique)):
-       for name in markov_blanket:
-         ie.chgEvidence(name, str(unique[name].iloc[i]))
-       ie.makeInference()
-       predicted = ie.posterior(self.target).toarray()[1]
-       result = result + predicted * unique["freq"].iloc[i] / len(df)
-     return self._logit(result)
-
-   def _evidenceImpact(self, condi, ie):
-     ie.eraseAllEvidence()
-     for key in condi.keys():
-       ie.addEvidence(key, str(condi[key]))
-     ie.makeInference()
-     return self._logit(ie.posterior(self.target).toarray()[1])
-
-   ################################## MARGINAL ##################################
-
-   def _predict_marginal(self, df, ie):
-     result = []
-     for i in range(len(df)):
-       for name in self.feats_names:
-         ie.chgEvidence(name, str(df[name].iloc[i]))
-       ie.makeInference()
-       result.append(ie.posterior(self.target).toarray()[1])
-     return np.array(result)
-
-   ################################## COMPUTE SHAP ##################################
-
-   def _compute_SHAP_i(self, S_U_i, S, v, size_S):
-     size_all_features = len(self.bn.nodes()) - 1
-     diff = v[f"{S_U_i}"] - v[f"{S}"]
-     return diff / self._invcoeff_shap(size_S, size_all_features)
-
-   def _invcoeff_shap(self, S_size, len_features):
-     return (len_features - S_size) * self._comb(len_features, S_size)
-
-   ################## Get the two coalitions to subtract: S_U_i and S ##################
-
-   @staticmethod
-   def _gen_coalitions2(list_feats):
-     ### NOTE: the feature i must be the LAST element of list_feats.
-     ### Yield, for each coalition S of the other features, the pair (S u {i}, S).
-     for mask in itertools.product(*[[0, 1]] * (len(list_feats) - 1)):
-       S_U_i = itertools.compress(list_feats, mask + (1,))
-       S = itertools.compress(list_feats, mask + (0,))
-       yield S_U_i, S
-
-   ################################## Function to compute CONDITIONAL Shap values ##################################
-   def _conditional(self, train: pd.DataFrame) -> Dict[str, float]:
-     """
-     Compute the conditional Shap values for each variable.
-
-     Parameters
-     ----------
-     train : pandas.DataFrame
-       the database
-
-     Returns
-     -------
-     a dictionary Dict[str, float]
-     """
-
-     ie = self._init_Inference()
-
-     v = train.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()
-
-     convert = self._get_list_names_order()
-
-     v["Baseline"] = self._evidenceImpact({}, ie)
-     for i in range(len(v)):
-       for coal in self._get_all_coal_compress():
-         S = list(coal)
-         condi = {}
-         for var in S:
-           condi[var] = v.loc[i, var]
-         col_arr_name = self._coal_encoding(convert, coal)
-         v.loc[i, f"{col_arr_name}"] = self._evidenceImpact(condi, ie)
-     df = pd.DataFrame()
-     for feat in self.feats_names:
-       list_i_last = self.feats_names.copy()
-       index_i = list_i_last.index(feat)
-       list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]
-
-       somme = 0
-       for coal1, coal2 in self._gen_coalitions2(list_i_last):
-         S_U_i = self._coal_encoding(convert, list(coal1))
-         S = self._coal_encoding(convert, list(coal2))
-         size_S = sum(S)
-         somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)
-       df[feat] = somme
-     self.results = df
-     return df, v
-
-   def conditional(self, train, plot=False, plot_importance=False, percentage=False, filename=None):
-     """
-     Compute the conditional Shap values for each variable.
-
-     Parameters
-     ----------
-     train : pandas.DataFrame
-       the database
-     plot : bool
-       if True, plot the beeswarm graph of the Shap values
-     plot_importance : bool
-       if True, plot the importance plot
-     percentage : bool
-       if True, the importance plot is shown in percent
-     filename : str
-       if not None, save the plot in the file
-
-     Returns
-     -------
-     a dictionary Dict[str, float]
-     """
-     results, v = self._conditional(train)
-     res = {}
-     for col in results.columns:
-       res[col] = abs(results[col]).mean()
-
-     self._plotResults(results, v, plot, plot_importance, percentage)
-
-     if plot or plot_importance:
-       if filename is not None:
-         plt.savefig(filename)
-       else:
-         plt.show()
-       plt.close()
-
-     return res
-
-   ################################## Function to compute MARGINAL Shap values ##################################
-   def _marginal(self, df, size_sample_df):
-     ie = self._init_Inference()
-     convert = self._get_list_names_order()
-     test = df[:size_sample_df]
-     v = df.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()
-     df = pd.DataFrame()
-
-     for i in range(len(v)):
-       for coal in self._get_all_coal_compress():
-         intervention = test.copy()
-         for var in coal:
-           intervention[var] = v.loc[i, var]
-         col_arr_name = self._coal_encoding(convert, coal)
-         v.loc[i, f"{col_arr_name}"] = np.mean(self._logit(self._predict_marginal(intervention, ie)))
-
-     for feat in self.feats_names:
-       list_i_last = self.feats_names.copy()
-       index_i = list_i_last.index(feat)
-       list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]
-
-       somme = 0
-       for coal1, coal2 in self._gen_coalitions2(list_i_last):
-         S_U_i = self._coal_encoding(convert, list(coal1))
-         S = self._coal_encoding(convert, list(coal2))
-         size_S = sum(S)
-         somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)
-
-       df[feat] = somme
-
-     self.results = df
-     return df, v
-
-   def marginal(self, train, sample_size=200, plot=False, plot_importance=False, percentage=False, filename=None):
-     """
-     Compute the marginal Shap values for each variable.
-
-     Parameters
-     ----------
-     train : pandas.DataFrame
-       the database
-     sample_size : int
-       The computation of marginal Shap values is very slow, so this parameter allows computing them on only a fragment of the database.
-     plot : bool
-       if True, plot the beeswarm graph of the Shap values
-     plot_importance : bool
-       if True, plot the importance plot
-     percentage : bool
-       if True, the importance plot is shown in percent
-     filename : str
-       if not None, save the plot in the file
-
-     Returns
-     -------
-     a dictionary Dict[str, float]
-     """
-     results, v = self._marginal(train, sample_size)
-     res = {}
-     for col in results.columns:
-       res[col] = abs(results[col]).mean()
-
-     self._plotResults(results, v, plot, plot_importance, percentage)
-
-     if plot or plot_importance:
-       if filename is not None:
-         plt.savefig(filename)
-       else:
-         plt.show()
-       plt.close()
-
-     return res
-
-
-   ################################## MUTILATION ##################################
-
-   @staticmethod
-   def _mutilation_Inference(bn, feats_name, target):
-     ie = gum.LazyPropagation(bn)
-     ie.addTarget(target)
-     for name in feats_name:
-       ie.addEvidence(name, 0)
-     return ie
-
-   def _causal(self, train):
-     v = train.groupby(self.feats_names).agg(freq=(self.feats_names[0], "count")).reset_index()
-     ie = self._init_Inference()
-
-     convert = self._get_list_names_order()
-     df = pd.DataFrame()
-
-     v["Baseline"] = self._evidenceImpact({}, ie)
-
-     for coal in self._get_all_coal_compress():
-       for i in range(len(v)):
-         bn_temp = gum.BayesNet(self.bn)
-         S = list(coal)
-         condi = {}
-         for var in S:
-           condi[var] = v.loc[i, var]
-           for parent in bn_temp.parents(var):
-             bn_temp.eraseArc(parent, bn_temp.idFromName(var))
-         ie = self._mutilation_Inference(bn_temp, self.feats_names, self.target)
-         col_arr_name = self._coal_encoding(convert, coal)
-         v.loc[i, f"{col_arr_name}"] = self._evidenceImpact(condi, ie)
-     for feat in self.feats_names:
-       list_i_last = self.feats_names.copy()
-       index_i = list_i_last.index(feat)
-       list_i_last[len(list_i_last) - 1], list_i_last[index_i] = list_i_last[index_i], list_i_last[len(list_i_last) - 1]
-
-       somme = 0
-       for coal1, coal2 in self._gen_coalitions2(list_i_last):
-         S_U_i = self._coal_encoding(convert, list(coal1))
-         S = self._coal_encoding(convert, list(coal2))
-         size_S = sum(S)
-         somme = somme + self._compute_SHAP_i(S_U_i, S, v, size_S)
-       df[feat] = somme
-
-     self.results = df
-     return df, v
-
-   def causal(self, train, plot=False, plot_importance=False, percentage=False, filename=None):
-     """
-     Compute the causal Shap values for each variable.
-
-     Parameters
-     ----------
-     train : pandas.DataFrame
-       the database
-     plot : bool
-       if True, plot the beeswarm graph of the Shap values
-     plot_importance : bool
-       if True, plot the importance plot
-     percentage : bool
-       if True, the importance plot is shown in percent
-     filename : str
-       if not None, save the plot in the file
-
-     Returns
-     -------
-     a dictionary Dict[str, float]
-     """
-     results, v = self._causal(train)
-
-     res = {}
-     for col in results.columns:
-       res[col] = abs(results[col]).mean()
-
-     self._plotResults(results, v, plot, plot_importance, percentage)
-
-     if plot or plot_importance:
-       if filename is not None:
-         plt.savefig(filename)
-       else:
-         plt.show()
-       plt.close()
-
-     return res
-
-   ################################## PLOT SHAP values ##################################
-
-   def _plotResults(self, results, v, plot=True, plot_importance=False, percentage=False):
-     ax1 = ax2 = None
-     if plot and plot_importance:
-       fig = plt.figure(figsize=(15, 0.5 * len(results.columns)))
-       ax1 = fig.add_subplot(1, 2, 1)
-       ax2 = fig.add_subplot(1, 2, 2)
-     if plot:
-       shap_dict = results.to_dict(orient="list")
-       sorted_dict = dict(sorted(shap_dict.items(), key=lambda x: sum(abs(i) for i in x[1]) / len(x[1])))
-       data = np.array([sorted_dict[key] for key in sorted_dict])
-       features = sorted_dict.keys()
-       v = v[features]
-       colors = v.transpose().to_numpy()
-       self._plot_beeswarm_(data, colors, 250, 1.5, features, ax=ax1)
-     if plot_importance:
-       self._plot_importance(results, percentage, ax=ax2)
-
-   @staticmethod
-   def _plot_beeswarm_(data, colors, N, K, features, cmap=None, ax=None):
-     """
-     Return a beeswarm plot (or strip plot) for the given data.
-
-     Parameters
-     ----------
-     data : list of numpy arrays.
-       Each element of the list is a numpy array containing the Shapley values to be displayed for one feature.
-
-     colors : list of numpy arrays.
-       Each element of the list is a numpy array containing the data-point values to be displayed for one feature.
-
-     Returns
-     -------
-     matplotlib.figure.Figure
-     """
-     min_value = np.min(data, axis=(0, 1))
-     max_value = np.max(data, axis=(0, 1))
-     bin_size = (max_value - min_value) / N
-     if bin_size == 0:
-       bin_size = 1
-     horiz_shift = K * bin_size
-
-     if ax is None:
-       _, ax = plt.subplots()
-     if cmap is None:
-       # Set the color map
-       ## Define the hex colors
-       color1 = gum.config["notebook", "tensor_color_0"]
-       color2 = gum.config["notebook", "tensor_color_1"]
-
-       ## Create the custom colormap
-       cmap = mcolors.LinearSegmentedColormap.from_list("", [color1, color2])
-
-     for n, d in enumerate(data):
-       pos = n + 1
-
-       d_shifted = d + np.random.normal(0, horiz_shift, len(d))
-       # Sort the values
-       d_sort = np.sort(d_shifted)
-
-       # Create the bins
-       d_x = bin_size
-       if (np.max(d_sort) - np.min(d_sort)) % d_x == 0:
-         nb_bins = (np.max(d_sort) - np.min(d_sort)) // d_x
-       else:
-         nb_bins = (np.max(d_sort) - np.min(d_sort)) // d_x + 1
-       bins = [np.min(d_sort) + i * d_x for i in range(int(nb_bins) + 1)]
-
-       # Group the values by bin
-       subarr = []
-       for k in range(1, len(bins)):
-         group = d_sort[np.logical_and(d_sort >= bins[k - 1], d_sort < bins[k])]
-         subarr.append(group)
-
-       # For each bin, compute the d_y (vertical shift)
-       d_y = 0.025
-       subarr_jitter = deepcopy(subarr)
-       for i in range(len(subarr)):
-         L = subarr[i].size
-         if L > 0:
-           for j in range(L):
-             shift = d_y * (L - 1) / 2 - d_y * j
-             subarr_jitter[i][j] = shift
-
-       jitter = np.concatenate(subarr_jitter)
-
-       sc = ax.scatter(d_shifted, pos + jitter, s=10, c=colors[n], cmap=cmap, alpha=0.7)
-
-     ## Create the colorbar
-     divider = make_axes_locatable(ax)
-     cax = divider.append_axes("right", size="5%", pad=0.1)
-
-     cbar = plt.colorbar(sc, cax=cax, aspect=80)
-     cbar.set_label("Data Point Value")
-
-     ## Add text above and below the colorbar
-     cax.text(0.5, 1.025, "High", transform=cax.transAxes, ha="center", va="center", fontsize=10)
-     cax.text(0.5, -0.025, "Low", transform=cax.transAxes, ha="center", va="center", fontsize=10)
-
-     ## Add y-axis tick labels
-     ax.set_yticks([i + 1 for i in range(len(features))])
-     ax.set_yticklabels(features)
-
-     ## Set axis labels and title
-     ax.set_ylabel("Features")
-     ax.set_xlabel("Impact on output")
-     ax.set_title("Shapley value (impact on model output)")
-
-     # Return the figure
-     return ax.get_figure()
-
-   @staticmethod
-   def _plot_importance(results, percentage=False, ax=None):
-     series = pd.DataFrame(abs(results).mean(), columns=["value"])
-     series["feat"] = series.index
-
-     if ax is None:
-       _, ax = plt.subplots()
-
-     if percentage:
-       series["value"] = series["value"].div(series["value"].sum(axis=0)).multiply(100)
-       series = series.sort_values("value", ascending=True)
-       ax.barh(series.feat, series.value, align="center")
-       ax.set_xlabel("Mean(|SHAP value|)")
-       ax.set_title("Feature Importance in %")
-     else:
-       series = series.sort_values("value", ascending=True)
-       ax.barh(series.feat, series.value, align="center")
-       ax.set_xlabel("Mean(|SHAP value|)")
-       ax.set_title("Feature Importance")
-
-     return ax.get_figure()
-
-   @staticmethod
-   def _plot_scatter(results, ax=None):
-     if ax is None:
-       _, ax = plt.subplots()
-
-     res = {}
-     for col in results.columns:
-       res[col] = results[col].to_numpy()
-     names = list(res.keys())
-     values = list(res.values())
-     for xe, ye in zip(names, values):
-       ax.scatter(ye, [xe] * len(ye))
-     ax.set_title("Shapley value (impact on model output)")
-     return ax.get_figure()
-
-   @staticmethod
-   def _plot_violin(results, ax=None):
-     data = []
-     pos = []
-     label = []
-     series = pd.DataFrame(abs(results).mean(), columns=["value"])
-     series = series.sort_values("value", ascending=True)
-     series["feat"] = series.index
-     for i, col in enumerate(series.feat):
-       data.append(results[col].to_numpy())
-       pos.append(i)
-       label.append(col)
-     if ax is None:
-       _, ax = plt.subplots()
-     ax.violinplot(data, pos, vert=False)
-     ax.set_yticks(pos)
-     ax.set_yticklabels(label)
-     ax.set_title("Shapley value (impact on model output)")
-     return ax.get_figure()
-
-
- def getShapValues(bn, shaps, cmap="plasma"):
-   """
-   Just a wrapper around BN2dot to easily show the Shap values.
-
-   Parameters
-   ----------
-   bn : pyagrum.BayesNet
-     The Bayesian network
-   shaps : dict[str, float]
-     The (Shap) values associated with each variable
-   cmap : Matplotlib.ColorMap
-     The colormap used for colouring the nodes
-
-   Returns
-   -------
-   a pydot.Dot graph
-   """
-   from pyagrum.lib.bn2graph import BN2dot
-
-   norm_color = {}
-   raw = list(shaps.values())
-   norm = [float(i) / sum(raw) for i in raw]
-   for i, feat in enumerate(list(shaps.keys())):
-     norm_color[feat] = norm[i]
-   cm = plt.get_cmap(cmap)
-   g = BN2dot(bn, nodeColor=norm_color, cmapNode=cm)
-   return g
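
For readers tracking this removal, here is a minimal usage sketch of the API that disappears with this file. It is written against a 2.1.x nightly, where pyagrum.lib.shapley still exists; the model path, database path and target name are placeholders, while the calls match the signatures of the deleted module above. In the 2.3.x nightlies, the equivalent entry points appear to live in the new pyagrum/explain subpackage (see pyagrum/explain/shapley.py in the file list).

  import pandas as pd
  import pyagrum as gum

  # Removed in this release; available in pyagrum-nightly 2.1.x.
  from pyagrum.lib.shapley import ShapValues, getShapValues

  bn = gum.loadBN("model.bif")      # placeholder: a discrete BN with a binary target "Y"
  train = pd.read_csv("train.csv")  # placeholder: columns must match the BN variable names

  explainer = ShapValues(bn, "Y")

  # Conditional Shap values (Lundberg & Lee): a Dict[str, float] of mean |phi_i| per feature.
  shap_cond = explainer.conditional(train)

  # Marginal (interventional) values, after Janzing et al.; slow, hence the sample_size cap.
  shap_marg = explainer.marginal(train, sample_size=50)

  # Causal values, after Heskes et al., computed on mutilated copies of the BN.
  shap_caus = explainer.causal(train, plot_importance=True, filename="importance.png")

  # Colour the BN nodes by their normalised Shap values (returns a pydot.Dot).
  g = getShapValues(bn, shap_cond)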