explainiverse 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {explainiverse-0.2.0 → explainiverse-0.2.1}/PKG-INFO +1 -1
  2. {explainiverse-0.2.0 → explainiverse-0.2.1}/pyproject.toml +1 -1
  3. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/__init__.py +1 -1
  4. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/core/registry.py +18 -0
  5. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/__init__.py +4 -1
  6. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/attribution/__init__.py +2 -1
  7. explainiverse-0.2.1/src/explainiverse/explainers/attribution/treeshap_wrapper.py +434 -0
  8. {explainiverse-0.2.0 → explainiverse-0.2.1}/LICENSE +0 -0
  9. {explainiverse-0.2.0 → explainiverse-0.2.1}/README.md +0 -0
  10. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/adapters/__init__.py +0 -0
  11. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/adapters/base_adapter.py +0 -0
  12. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/adapters/sklearn_adapter.py +0 -0
  13. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/core/__init__.py +0 -0
  14. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/core/explainer.py +0 -0
  15. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/core/explanation.py +0 -0
  16. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/engine/__init__.py +0 -0
  17. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/engine/suite.py +0 -0
  18. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/evaluation/__init__.py +0 -0
  19. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/evaluation/metrics.py +0 -0
  20. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/attribution/lime_wrapper.py +0 -0
  21. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/attribution/shap_wrapper.py +0 -0
  22. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/counterfactual/__init__.py +0 -0
  23. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/counterfactual/dice_wrapper.py +0 -0
  24. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/global_explainers/__init__.py +0 -0
  25. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/global_explainers/ale.py +0 -0
  26. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/global_explainers/partial_dependence.py +0 -0
  27. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/global_explainers/permutation_importance.py +0 -0
  28. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/global_explainers/sage.py +0 -0
  29. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/rule_based/__init__.py +0 -0
  30. {explainiverse-0.2.0 → explainiverse-0.2.1}/src/explainiverse/explainers/rule_based/anchors_wrapper.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: explainiverse
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
5
5
  Home-page: https://github.com/jemsbhai/explainiverse
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "explainiverse"
3
- version = "0.2.0"
3
+ version = "0.2.1"
4
4
  description = "Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more"
5
5
  authors = ["Muntaser Syed <jemsbhai@gmail.com>"]
6
6
  license = "MIT"
@@ -27,7 +27,7 @@ from explainiverse.core.registry import (
27
27
  from explainiverse.adapters.sklearn_adapter import SklearnAdapter
28
28
  from explainiverse.engine.suite import ExplanationSuite
29
29
 
30
- __version__ = "0.2.0"
30
+ __version__ = "0.2.1"
31
31
 
32
32
  __all__ = [
33
33
  # Core
@@ -362,6 +362,7 @@ def _create_default_registry() -> ExplainerRegistry:
362
362
  """Create and populate the default global registry."""
363
363
  from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
364
364
  from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer
365
+ from explainiverse.explainers.attribution.treeshap_wrapper import TreeShapExplainer
365
366
  from explainiverse.explainers.rule_based.anchors_wrapper import AnchorsExplainer
366
367
  from explainiverse.explainers.global_explainers.permutation_importance import PermutationImportanceExplainer
367
368
  from explainiverse.explainers.global_explainers.partial_dependence import PartialDependenceExplainer
@@ -409,6 +410,23 @@ def _create_default_registry() -> ExplainerRegistry:
409
410
  )
410
411
  )
411
412
 
413
+ # Register TreeSHAP (optimized for tree models)
414
+ registry.register(
415
+ name="treeshap",
416
+ explainer_class=TreeShapExplainer,
417
+ meta=ExplainerMeta(
418
+ scope="local",
419
+ model_types=["tree", "ensemble"],
420
+ data_types=["tabular"],
421
+ task_types=["classification", "regression"],
422
+ description="TreeSHAP - exact SHAP values for tree-based models (RandomForest, XGBoost, etc.)",
423
+ paper_reference="Lundberg et al., 2018 - 'Consistent Individualized Feature Attribution for Tree Ensembles'",
424
+ complexity="O(TLD^2) - polynomial in tree depth",
425
+ requires_training_data=False,
426
+ supports_batching=True
427
+ )
428
+ )
429
+
412
430
  # Register Anchors
413
431
  registry.register(
414
432
  name="anchors",
@@ -4,7 +4,8 @@ Explainiverse Explainers - comprehensive XAI method implementations.
4
4
 
5
5
  Local Explainers (instance-level):
6
6
  - LIME: Local Interpretable Model-agnostic Explanations
7
- - SHAP: SHapley Additive exPlanations
7
+ - SHAP: SHapley Additive exPlanations (KernelSHAP - model-agnostic)
8
+ - TreeSHAP: Optimized exact SHAP for tree-based models
8
9
  - Anchors: High-precision rule-based explanations
9
10
  - Counterfactual: Diverse counterfactual explanations
10
11
 
@@ -17,6 +18,7 @@ Global Explainers (model-level):
17
18
 
18
19
  from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
19
20
  from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer
21
+ from explainiverse.explainers.attribution.treeshap_wrapper import TreeShapExplainer
20
22
  from explainiverse.explainers.rule_based.anchors_wrapper import AnchorsExplainer
21
23
  from explainiverse.explainers.counterfactual.dice_wrapper import CounterfactualExplainer
22
24
  from explainiverse.explainers.global_explainers.permutation_importance import PermutationImportanceExplainer
@@ -28,6 +30,7 @@ __all__ = [
28
30
  # Local explainers
29
31
  "LimeExplainer",
30
32
  "ShapExplainer",
33
+ "TreeShapExplainer",
31
34
  "AnchorsExplainer",
32
35
  "CounterfactualExplainer",
33
36
  # Global explainers
@@ -5,5 +5,6 @@ Attribution-based explainers - feature importance explanations.
5
5
 
6
6
  from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
7
7
  from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer
8
+ from explainiverse.explainers.attribution.treeshap_wrapper import TreeShapExplainer
8
9
 
9
- __all__ = ["LimeExplainer", "ShapExplainer"]
10
+ __all__ = ["LimeExplainer", "ShapExplainer", "TreeShapExplainer"]
@@ -0,0 +1,434 @@
1
+ # src/explainiverse/explainers/attribution/treeshap_wrapper.py
2
+ """
3
+ TreeSHAP Explainer - Optimized SHAP for Tree-based Models.
4
+
5
+ TreeSHAP computes exact SHAP values in polynomial time for tree-based models,
6
+ making it significantly faster than KernelSHAP while providing exact (not
7
+ approximate) Shapley values.
8
+
9
+ Reference:
10
+ Lundberg, S.M., Erion, G.G., & Lee, S.I. (2018). Consistent Individualized
11
+ Feature Attribution for Tree Ensembles. arXiv:1802.03888.
12
+
13
+ Supported Models:
14
+ - scikit-learn: RandomForest, GradientBoosting, DecisionTree, ExtraTrees
15
+ - XGBoost: XGBClassifier, XGBRegressor
16
+ - LightGBM: LGBMClassifier, LGBMRegressor (if installed)
17
+ - CatBoost: CatBoostClassifier, CatBoostRegressor (if installed)
18
+ """
19
+
20
+ import numpy as np
21
+ import shap
22
+ from typing import List, Optional, Union
23
+
24
+ from explainiverse.core.explainer import BaseExplainer
25
+ from explainiverse.core.explanation import Explanation
26
+
27
+
28
+ # Tree-based model types that TreeSHAP supports
29
+ SUPPORTED_TREE_MODELS = (
30
+ "RandomForestClassifier",
31
+ "RandomForestRegressor",
32
+ "GradientBoostingClassifier",
33
+ "GradientBoostingRegressor",
34
+ "DecisionTreeClassifier",
35
+ "DecisionTreeRegressor",
36
+ "ExtraTreesClassifier",
37
+ "ExtraTreesRegressor",
38
+ "XGBClassifier",
39
+ "XGBRegressor",
40
+ "XGBRFClassifier",
41
+ "XGBRFRegressor",
42
+ "LGBMClassifier",
43
+ "LGBMRegressor",
44
+ "CatBoostClassifier",
45
+ "CatBoostRegressor",
46
+ "HistGradientBoostingClassifier",
47
+ "HistGradientBoostingRegressor",
48
+ )
49
+
50
+
51
+ def _is_tree_model(model) -> bool:
52
+ """Check if a model is a supported tree-based model."""
53
+ model_name = type(model).__name__
54
+ return model_name in SUPPORTED_TREE_MODELS
55
+
56
+
57
+ def _get_raw_model(model):
58
+ """
59
+ Extract the raw model from an adapter if necessary.
60
+
61
+ TreeExplainer needs the actual sklearn/xgboost model, not an adapter.
62
+ """
63
+ # If it's an adapter, get the underlying model
64
+ if hasattr(model, 'model'):
65
+ return model.model
66
+ return model
67
+
68
+
69
class TreeShapExplainer(BaseExplainer):
    """
    TreeSHAP explainer for tree-based models.

    Uses SHAP's TreeExplainer to compute exact SHAP values in polynomial time.
    This is significantly faster than KernelSHAP for supported tree models
    and provides exact Shapley values rather than approximations.

    Key advantages over KernelSHAP:
        - Exact SHAP values (not approximations)
        - O(TLD²) complexity vs O(TL2^M) for KernelSHAP
        - Can compute interaction values
        - No background data sampling needed

    Attributes:
        model: The tree-based model (sklearn, XGBoost, LightGBM, or CatBoost)
        feature_names: List of feature names
        class_names: List of class names for classification
        explainer: The underlying SHAP TreeExplainer
        task: "classification" or "regression"
    """

    def __init__(
        self,
        model,
        feature_names: List[str],
        class_names: Optional[List[str]] = None,
        background_data: Optional[np.ndarray] = None,
        task: str = "classification",
        model_output: str = "auto",
        feature_perturbation: str = "tree_path_dependent"
    ):
        """
        Initialize the TreeSHAP explainer.

        Args:
            model: A tree-based model or adapter containing one.
                Supported: RandomForest, GradientBoosting, XGBoost,
                LightGBM, CatBoost, DecisionTree, ExtraTrees.
            feature_names: List of feature names.
            class_names: List of class names (for classification).
            background_data: Optional background dataset for interventional
                feature perturbation. If None, uses tree_path_dependent.
            task: "classification" or "regression".
            model_output: How to transform model output. Options:
                - "auto": Automatically detect
                - "raw": Raw model output
                - "probability": Probability output (classification)
                - "log_loss": Log loss output
            feature_perturbation: Method for handling feature perturbation:
                - "tree_path_dependent": Fast, uses tree structure
                - "interventional": Slower, requires background data

        Raises:
            ValueError: If the (unwrapped) model's class name is not in
                SUPPORTED_TREE_MODELS.
        """
        # Extract raw model if wrapped in adapter
        raw_model = _get_raw_model(model)

        # Validate that it's a supported tree model
        if not _is_tree_model(raw_model):
            model_type = type(raw_model).__name__
            raise ValueError(
                f"TreeSHAP requires a tree-based model. Got {model_type}. "
                f"Supported models: {', '.join(SUPPORTED_TREE_MODELS[:6])}..."
            )

        super().__init__(model)
        self.raw_model = raw_model
        self.feature_names = list(feature_names)
        # NOTE: an empty class_names list is treated the same as None here.
        self.class_names = list(class_names) if class_names else None
        self.task = task
        self.model_output = model_output
        self.feature_perturbation = feature_perturbation

        # Create TreeExplainer
        explainer_kwargs = {}

        # "interventional" is only honored when background data is supplied;
        # otherwise shap's default (tree_path_dependent) is used silently.
        if feature_perturbation == "interventional" and background_data is not None:
            explainer_kwargs["data"] = background_data
            explainer_kwargs["feature_perturbation"] = "interventional"

        if model_output != "auto":
            explainer_kwargs["model_output"] = model_output

        self.explainer = shap.TreeExplainer(raw_model, **explainer_kwargs)
        self.background_data = background_data

    def explain(
        self,
        instance: np.ndarray,
        target_class: Optional[int] = None,
        check_additivity: bool = False
    ) -> Explanation:
        """
        Generate TreeSHAP explanation for a single instance.

        Args:
            instance: 1D numpy array of input features.
            target_class: For multi-class, which class to explain.
                If None, uses the predicted class.
            check_additivity: Whether to verify SHAP values sum to
                prediction - expected_value.

        Returns:
            Explanation object with feature attributions.
        """
        instance = np.array(instance).flatten()
        instance_2d = instance.reshape(1, -1)

        # Compute SHAP values
        shap_values = self.explainer.shap_values(
            instance_2d,
            check_additivity=check_additivity
        )

        # Handle different output formats
        if isinstance(shap_values, list):
            # Multi-class classification: list of arrays, one per class
            n_classes = len(shap_values)

            if target_class is None:
                # Use predicted class
                if hasattr(self.raw_model, 'predict'):
                    pred = self.raw_model.predict(instance_2d)[0]
                    target_class = int(pred)
                else:
                    target_class = 0

            # Ensure target_class is valid
            # NOTE(review): out-of-range target_class is silently clamped to
            # the last class rather than raising — confirm that is intended.
            target_class = min(target_class, n_classes - 1)
            class_shap = shap_values[target_class][0]

            # Get class name
            if self.class_names and target_class < len(self.class_names):
                label_name = self.class_names[target_class]
            else:
                label_name = f"class_{target_class}"

            # Store all class SHAP values for reference
            all_class_shap = {
                (self.class_names[i] if self.class_names and i < len(self.class_names)
                 else f"class_{i}"): shap_values[i][0].tolist()
                for i in range(n_classes)
            }
        else:
            # Binary classification or regression
            # NOTE(review): newer shap releases return a single ndarray for
            # multi-class (shape (n_samples, n_features, n_classes)) instead
            # of a list; such output would land in this branch and the flatten
            # below would mix per-class values — confirm the pinned shap
            # version returns the list format for multi-class models.
            class_shap = shap_values[0] if shap_values.ndim > 1 else shap_values.flatten()
            label_name = self.class_names[1] if self.class_names and len(self.class_names) > 1 else "output"
            all_class_shap = None

        # Build attributions dict
        # (only the first len(feature_names) values of flat_shap are used)
        flat_shap = np.array(class_shap).flatten()
        attributions = {
            fname: float(flat_shap[i])
            for i, fname in enumerate(self.feature_names)
        }

        # Get expected value (base value)
        expected_value = self.explainer.expected_value
        if isinstance(expected_value, (list, np.ndarray)):
            # target_class can still be None here (non-list branch above);
            # in that case the first entry is used as the base value.
            if target_class is not None and target_class < len(expected_value):
                base_value = float(expected_value[target_class])
            else:
                base_value = float(expected_value[0])
        else:
            base_value = float(expected_value)

        explanation_data = {
            "feature_attributions": attributions,
            "base_value": base_value,
            "shap_values_raw": flat_shap.tolist(),
        }

        if all_class_shap is not None:
            explanation_data["all_class_shap_values"] = all_class_shap

        return Explanation(
            explainer_name="TreeSHAP",
            target_class=label_name,
            explanation_data=explanation_data
        )

    def explain_batch(
        self,
        X: np.ndarray,
        target_class: Optional[int] = None,
        check_additivity: bool = False
    ) -> List[Explanation]:
        """
        Generate TreeSHAP explanations for multiple instances efficiently.

        TreeSHAP can process batches more efficiently than individual calls.

        Args:
            X: 2D numpy array of instances (n_samples, n_features).
            target_class: For multi-class, which class to explain.
                Unlike explain(), None defaults to class 0 here (no
                per-instance prediction is made).
            check_additivity: Whether to verify SHAP value additivity.

        Returns:
            List of Explanation objects.
        """
        X = np.array(X)
        if X.ndim == 1:
            X = X.reshape(1, -1)

        # Compute SHAP values for all instances at once
        shap_values = self.explainer.shap_values(X, check_additivity=check_additivity)

        explanations = []
        for i in range(X.shape[0]):
            if isinstance(shap_values, list):
                # Multi-class
                n_classes = len(shap_values)
                tc = target_class if target_class is not None else 0
                tc = min(tc, n_classes - 1)
                class_shap = shap_values[tc][i]

                if self.class_names and tc < len(self.class_names):
                    label_name = self.class_names[tc]
                else:
                    label_name = f"class_{tc}"
            else:
                # Binary classification or regression (see the shap-version
                # note in explain() — same caveat applies to ndarray output).
                class_shap = shap_values[i]
                label_name = self.class_names[1] if self.class_names and len(self.class_names) > 1 else "output"

            flat_shap = np.array(class_shap).flatten()
            attributions = {
                fname: float(flat_shap[j])
                for j, fname in enumerate(self.feature_names)
            }

            # Base value is re-read each iteration; it is constant across the
            # batch, so this is redundant but harmless work.
            expected_value = self.explainer.expected_value
            if isinstance(expected_value, (list, np.ndarray)):
                tc = target_class if target_class is not None else 0
                base_value = float(expected_value[min(tc, len(expected_value) - 1)])
            else:
                base_value = float(expected_value)

            explanations.append(Explanation(
                explainer_name="TreeSHAP",
                target_class=label_name,
                explanation_data={
                    "feature_attributions": attributions,
                    "base_value": base_value,
                    "shap_values_raw": flat_shap.tolist(),
                }
            ))

        return explanations

    def explain_interactions(
        self,
        instance: np.ndarray,
        target_class: Optional[int] = None
    ) -> Explanation:
        """
        Compute SHAP interaction values for an instance.

        Interaction values show how pairs of features jointly contribute
        to the prediction. The diagonal contains main effects.

        Args:
            instance: 1D numpy array of input features.
            target_class: For multi-class, which class to explain.

        Returns:
            Explanation object with interaction matrix.
        """
        instance = np.array(instance).flatten()
        instance_2d = instance.reshape(1, -1)

        # Compute interaction values
        interaction_values = self.explainer.shap_interaction_values(instance_2d)

        # Determine target class for prediction
        if target_class is None and hasattr(self.raw_model, 'predict'):
            target_class = int(self.raw_model.predict(instance_2d)[0])
        elif target_class is None:
            target_class = 0

        # Handle different return formats from shap_interaction_values
        if isinstance(interaction_values, list):
            # Multi-class: list of arrays, one per class
            n_classes = len(interaction_values)
            tc = min(target_class, n_classes - 1)
            interactions = np.array(interaction_values[tc][0])

            if self.class_names and tc < len(self.class_names):
                label_name = self.class_names[tc]
            else:
                label_name = f"class_{tc}"
        elif interaction_values.ndim == 4:
            # Shape: (n_samples, n_features, n_features, n_classes)
            n_classes = interaction_values.shape[3]
            tc = min(target_class, n_classes - 1)
            interactions = interaction_values[0, :, :, tc]

            if self.class_names and tc < len(self.class_names):
                label_name = self.class_names[tc]
            else:
                label_name = f"class_{tc}"
        else:
            # Binary or regression: (n_samples, n_features, n_features)
            interactions = interaction_values[0]
            label_name = self.class_names[1] if self.class_names and len(self.class_names) > 1 else "output"

        # Ensure interactions is 2D (n_features x n_features)
        interactions = np.array(interactions)
        if interactions.ndim > 2:
            # If still multi-dimensional, take first slice
            # NOTE(review): this picks class 0, not target_class — confirm
            # this fallback is ever reached with the supported shap versions.
            interactions = interactions[:, :, 0] if interactions.ndim == 3 else interactions

        # Build interaction dict with feature name pairs
        n_features = len(self.feature_names)
        interaction_dict = {}
        main_effects = {}

        for i in range(n_features):
            fname_i = self.feature_names[i]
            val = interactions[i, i]
            main_effects[fname_i] = float(val) if np.isscalar(val) or val.size == 1 else float(val.flat[0])

            for j in range(i + 1, n_features):
                fname_j = self.feature_names[j]
                # Interaction values are symmetric, so we sum both directions
                val_ij = interactions[i, j]
                val_ji = interactions[j, i]
                ij = float(val_ij) if np.isscalar(val_ij) or val_ij.size == 1 else float(val_ij.flat[0])
                ji = float(val_ji) if np.isscalar(val_ji) or val_ji.size == 1 else float(val_ji.flat[0])
                interaction_dict[f"{fname_i} x {fname_j}"] = ij + ji

        # Sort interactions by absolute value
        sorted_interactions = dict(sorted(
            interaction_dict.items(),
            key=lambda x: abs(x[1]),
            reverse=True
        ))

        return Explanation(
            explainer_name="TreeSHAP_Interactions",
            target_class=label_name,
            explanation_data={
                "feature_attributions": main_effects,
                "interactions": sorted_interactions,
                "interaction_matrix": interactions.tolist(),
                "feature_names": self.feature_names
            }
        )

    def get_expected_value(self, target_class: Optional[int] = None) -> float:
        """
        Get the expected (base) value of the model.

        This is the average model output over the background dataset.

        Args:
            target_class: For multi-class, which class's expected value.
                If None, the first entry is returned; out-of-range indices
                are clamped to the last entry.

        Returns:
            The expected value as a float.
        """
        expected_value = self.explainer.expected_value

        if isinstance(expected_value, (list, np.ndarray)):
            tc = target_class if target_class is not None else 0
            return float(expected_value[min(tc, len(expected_value) - 1)])

        return float(expected_value)
File without changes
File without changes