explainiverse 0.7.0__tar.gz → 0.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {explainiverse-0.7.0 → explainiverse-0.7.1}/PKG-INFO +2 -2
  2. {explainiverse-0.7.0 → explainiverse-0.7.1}/README.md +1 -1
  3. {explainiverse-0.7.0 → explainiverse-0.7.1}/pyproject.toml +1 -1
  4. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/__init__.py +1 -1
  5. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/adapters/pytorch_adapter.py +88 -25
  6. explainiverse-0.7.1/src/explainiverse/core/explanation.py +179 -0
  7. explainiverse-0.7.1/src/explainiverse/engine/suite.py +252 -0
  8. explainiverse-0.7.1/src/explainiverse/evaluation/metrics.py +314 -0
  9. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/attribution/lime_wrapper.py +90 -7
  10. explainiverse-0.7.1/src/explainiverse/explainers/attribution/shap_wrapper.py +185 -0
  11. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/integrated_gradients.py +189 -76
  12. explainiverse-0.7.0/src/explainiverse/core/explanation.py +0 -24
  13. explainiverse-0.7.0/src/explainiverse/engine/suite.py +0 -143
  14. explainiverse-0.7.0/src/explainiverse/evaluation/metrics.py +0 -233
  15. explainiverse-0.7.0/src/explainiverse/explainers/attribution/shap_wrapper.py +0 -89
  16. {explainiverse-0.7.0 → explainiverse-0.7.1}/LICENSE +0 -0
  17. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/adapters/__init__.py +0 -0
  18. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/adapters/base_adapter.py +0 -0
  19. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/adapters/sklearn_adapter.py +0 -0
  20. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/core/__init__.py +0 -0
  21. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/core/explainer.py +0 -0
  22. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/core/registry.py +0 -0
  23. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/engine/__init__.py +0 -0
  24. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/evaluation/__init__.py +0 -0
  25. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/evaluation/_utils.py +0 -0
  26. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/evaluation/faithfulness.py +0 -0
  27. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/evaluation/stability.py +0 -0
  28. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/__init__.py +0 -0
  29. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/attribution/__init__.py +0 -0
  30. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/attribution/treeshap_wrapper.py +0 -0
  31. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/counterfactual/__init__.py +0 -0
  32. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/counterfactual/dice_wrapper.py +0 -0
  33. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/example_based/__init__.py +0 -0
  34. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/example_based/protodash.py +0 -0
  35. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/__init__.py +0 -0
  36. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/ale.py +0 -0
  37. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/partial_dependence.py +0 -0
  38. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/permutation_importance.py +0 -0
  39. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/sage.py +0 -0
  40. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/__init__.py +0 -0
  41. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/deeplift.py +0 -0
  42. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/gradcam.py +0 -0
  43. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/saliency.py +0 -0
  44. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/smoothgrad.py +0 -0
  45. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/tcav.py +0 -0
  46. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/rule_based/__init__.py +0 -0
  47. {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/rule_based/anchors_wrapper.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: explainiverse
3
- Version: 0.7.0
3
+ Version: 0.7.1
4
4
  Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
5
5
  Home-page: https://github.com/jemsbhai/explainiverse
6
6
  License: MIT
@@ -671,7 +671,7 @@ If you use Explainiverse in your research, please cite:
671
671
  author = {Syed, Muntaser},
672
672
  year = {2025},
673
673
  url = {https://github.com/jemsbhai/explainiverse},
674
- version = {0.7.0}
674
+ version = {0.7.1}
675
675
  }
676
676
  ```
677
677
 
@@ -640,7 +640,7 @@ If you use Explainiverse in your research, please cite:
640
640
  author = {Syed, Muntaser},
641
641
  year = {2025},
642
642
  url = {https://github.com/jemsbhai/explainiverse},
643
- version = {0.7.0}
643
+ version = {0.7.1}
644
644
  }
645
645
  ```
646
646
 
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "explainiverse"
3
- version = "0.7.0"
3
+ version = "0.7.1"
4
4
  description = "Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more"
5
5
  authors = ["Muntaser Syed <jemsbhai@gmail.com>"]
6
6
  license = "MIT"
@@ -33,7 +33,7 @@ from explainiverse.adapters.sklearn_adapter import SklearnAdapter
33
33
  from explainiverse.adapters import TORCH_AVAILABLE
34
34
  from explainiverse.engine.suite import ExplanationSuite
35
35
 
36
- __version__ = "0.7.0"
36
+ __version__ = "0.7.1"
37
37
 
38
38
  __all__ = [
39
39
  # Core
@@ -25,7 +25,7 @@ Example:
25
25
  """
26
26
 
27
27
  import numpy as np
28
- from typing import List, Optional, Union, Callable
28
+ from typing import List, Optional, Union, Tuple
29
29
 
30
30
  from .base_adapter import BaseModelAdapter
31
31
 
@@ -57,6 +57,11 @@ class PyTorchAdapter(BaseModelAdapter):
57
57
  explainability methods. Handles device management, tensor/numpy
58
58
  conversions, and supports both classification and regression tasks.
59
59
 
60
+ Supports:
61
+ - Multi-class classification (output shape: [batch, n_classes])
62
+ - Binary classification (output shape: [batch, 1] or [batch])
63
+ - Regression (output shape: [batch, n_outputs] or [batch])
64
+
60
65
  Attributes:
61
66
  model: The PyTorch model (nn.Module)
62
67
  task: "classification" or "regression"
@@ -150,11 +155,27 @@ class PyTorchAdapter(BaseModelAdapter):
150
155
  def _apply_activation(self, output: "torch.Tensor") -> "torch.Tensor":
151
156
  """Apply output activation function."""
152
157
  if self.output_activation == "softmax":
158
+ # Handle different output shapes
159
+ if output.dim() == 1 or (output.dim() == 2 and output.shape[1] == 1):
160
+ # Binary: apply sigmoid instead of softmax
161
+ return torch.sigmoid(output)
153
162
  return torch.softmax(output, dim=-1)
154
163
  elif self.output_activation == "sigmoid":
155
164
  return torch.sigmoid(output)
156
165
  return output
157
166
 
167
def _normalize_output_shape(self, output: "torch.Tensor") -> "torch.Tensor":
    """
    Coerce a model output into a ``(batch, outputs)`` tensor.

    A 1-D output of shape ``(batch,)`` gains a trailing axis and becomes
    ``(batch, 1)``. Tensors of any other rank are returned as-is (same object).
    """
    needs_trailing_axis = output.dim() == 1
    return output.unsqueeze(-1) if needs_trailing_axis else output
178
+
158
179
  def predict(self, data: np.ndarray) -> np.ndarray:
159
180
  """
160
181
  Generate predictions for input data.
@@ -183,16 +204,66 @@ class PyTorchAdapter(BaseModelAdapter):
183
204
  tensor_batch = self._to_tensor(batch)
184
205
 
185
206
  output = self.model(tensor_batch)
207
+ output = self._normalize_output_shape(output)
186
208
  output = self._apply_activation(output)
187
209
  outputs.append(self._to_numpy(output))
188
210
 
189
211
  return np.vstack(outputs)
190
212
 
213
def _get_target_scores(
    self,
    output: "torch.Tensor",
    target_class: Optional[Union[int, "torch.Tensor"]] = None
) -> "torch.Tensor":
    """
    Extract per-sample target scores for gradient computation.

    Handles multi-class, single-logit binary, and regression outputs.

    Args:
        output: Raw model output (logits), shape ``(batch,)`` or
            ``(batch, n_outputs)``.
        target_class: Class to explain for multi-class classification.
            ``None`` selects each sample's argmax. A scalar (Python int,
            NumPy integer, or 1-element tensor/array) is applied to every
            sample. A sequence/array/tensor of length ``batch`` gives one
            index per sample. Ignored for single-logit binary
            classification and for regression.

    Returns:
        1-D tensor of shape ``(batch,)`` computed from ``output`` with
        differentiable ops (squeeze / gather / sum), suitable for backward.
    """
    # Normalize to 2D so column indexing below is uniform.
    if output.dim() == 1:
        output = output.unsqueeze(-1)
    batch_size = output.shape[0]
    n_outputs = output.shape[1]

    if self.task != "classification":
        # Regression: a single output is used directly; multi-output
        # regression backpropagates through the sum of all outputs.
        return output.squeeze(-1) if n_outputs == 1 else output.sum(dim=-1)

    if n_outputs == 1:
        # Binary classification with a single logit: the score is the
        # logit itself (positive-class score); target_class is ignored.
        return output.squeeze(-1)

    if target_class is None:
        indices = output.argmax(dim=-1)
    else:
        # as_tensor accepts ints, NumPy integers, sequences, arrays and
        # tensors; .long() is required because gather needs int64 indices,
        # and placing it on output.device avoids cross-device gather errors.
        indices = torch.as_tensor(target_class, device=output.device).long().reshape(-1)
        if indices.numel() == 1:
            # Broadcast a single class index to every sample in the batch.
            indices = indices.repeat(batch_size)

    return output.gather(1, indices.reshape(-1, 1)).squeeze(-1)
261
+
191
262
  def predict_with_gradients(
192
263
  self,
193
264
  data: np.ndarray,
194
265
  target_class: Optional[int] = None
195
- ) -> tuple:
266
+ ) -> Tuple[np.ndarray, np.ndarray]:
196
267
  """
197
268
  Generate predictions and compute gradients w.r.t. inputs.
198
269
 
@@ -203,11 +274,17 @@ class PyTorchAdapter(BaseModelAdapter):
203
274
  data: Input data as numpy array.
204
275
  target_class: Class index for gradient computation.
205
276
  If None, uses the predicted class.
277
+ For binary classification with single output,
278
+ this is ignored (gradient w.r.t. the single logit).
206
279
 
207
280
  Returns:
208
281
  Tuple of (predictions, gradients) as numpy arrays.
282
+ - predictions: (batch, n_classes) probabilities
283
+ - gradients: same shape as input data
209
284
  """
210
285
  data = np.array(data)
286
+ original_shape = data.shape
287
+
211
288
  if data.ndim == 1:
212
289
  data = data.reshape(1, -1)
213
290
 
@@ -217,20 +294,13 @@ class PyTorchAdapter(BaseModelAdapter):
217
294
 
218
295
  # Forward pass
219
296
  output = self.model(tensor_data)
220
- activated_output = self._apply_activation(output)
221
297
 
222
- # Determine target for gradient
223
- if self.task == "classification":
224
- if target_class is None:
225
- target_class = output.argmax(dim=-1)
226
- elif isinstance(target_class, int):
227
- target_class = torch.tensor([target_class] * data.shape[0], device=self.device)
228
-
229
- # Select target class scores for gradient
230
- target_scores = output.gather(1, target_class.view(-1, 1)).squeeze()
231
- else:
232
- # Regression: gradient w.r.t. output
233
- target_scores = output.squeeze()
298
+ # Get activated output for return
299
+ output_normalized = self._normalize_output_shape(output)
300
+ activated_output = self._apply_activation(output_normalized)
301
+
302
+ # Get target scores for gradient computation
303
+ target_scores = self._get_target_scores(output, target_class)
234
304
 
235
305
  # Backward pass
236
306
  if target_scores.dim() == 0:
@@ -295,7 +365,7 @@ class PyTorchAdapter(BaseModelAdapter):
295
365
  data: np.ndarray,
296
366
  layer_name: str,
297
367
  target_class: Optional[int] = None
298
- ) -> tuple:
368
+ ) -> Tuple[np.ndarray, np.ndarray]:
299
369
  """
300
370
  Get gradients of output w.r.t. a specific layer's activations.
301
371
 
@@ -339,15 +409,8 @@ class PyTorchAdapter(BaseModelAdapter):
339
409
 
340
410
  output = self.model(tensor_data)
341
411
 
342
- if self.task == "classification":
343
- if target_class is None:
344
- target_class = output.argmax(dim=-1)
345
- elif isinstance(target_class, int):
346
- target_class = torch.tensor([target_class] * data.shape[0], device=self.device)
347
-
348
- target_scores = output.gather(1, target_class.view(-1, 1)).squeeze()
349
- else:
350
- target_scores = output.squeeze()
412
+ # Get target scores using the new method
413
+ target_scores = self._get_target_scores(output, target_class)
351
414
 
352
415
  if target_scores.dim() == 0:
353
416
  target_scores.backward()
@@ -0,0 +1,179 @@
1
+ # src/explainiverse/core/explanation.py
2
+ """
3
+ Unified container for explanation results.
4
+
5
+ The Explanation class provides a standardized format for all explainer outputs,
6
+ enabling consistent handling across different explanation methods.
7
+ """
8
+
9
+ from typing import Dict, List, Optional, Any
10
+
11
+
12
class Explanation:
    """
    Unified container for explanation results.

    Attributes:
        explainer_name: Name of the explainer that generated this explanation
        target_class: The class/output being explained
        explanation_data: Dictionary containing explanation details
            (e.g., feature_attributions, heatmaps, rules)
        feature_names: Optional list of feature names for index resolution
        metadata: Optional additional metadata about the explanation

    Example:
        >>> explanation = Explanation(
        ...     explainer_name="LIME",
        ...     target_class="cat",
        ...     explanation_data={"feature_attributions": {"fur": 0.8, "whiskers": 0.6}},
        ...     feature_names=["fur", "whiskers", "tail", "ears"]
        ... )
        >>> print(explanation.get_top_features(k=2))
        [('fur', 0.8), ('whiskers', 0.6)]
    """

    def __init__(
        self,
        explainer_name: str,
        target_class: str,
        explanation_data: Dict[str, Any],
        feature_names: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize an Explanation object.

        Args:
            explainer_name: Name of the explainer (e.g., "LIME", "SHAP")
            target_class: The target class or output being explained
            explanation_data: Dictionary containing the explanation details.
                Common keys include:
                - "feature_attributions": Dict[str, float] mapping feature names to importance
                - "attributions_raw": List[float] of raw attribution values
                - "heatmap": np.ndarray for image explanations
                - "rules": List of rule strings for rule-based explanations
            feature_names: Optional list of feature names. If provided, enables
                index-based lookup in evaluation metrics. The list is copied.
            metadata: Optional additional metadata (e.g., computation time, parameters)
        """
        self.explainer_name = explainer_name
        self.target_class = target_class
        self.explanation_data = explanation_data
        # Copy so later mutation of the caller's sequence does not leak in.
        self.feature_names = list(feature_names) if feature_names is not None else None
        self.metadata = metadata or {}

    def __repr__(self):
        # An empty list is still a known feature set (0 features);
        # only None means the names were never provided.
        n_features = len(self.feature_names) if self.feature_names is not None else "N/A"
        return (
            f"Explanation(explainer='{self.explainer_name}', "
            f"target='{self.target_class}', "
            f"keys={list(self.explanation_data.keys())}, "
            f"n_features={n_features})"
        )

    def get_attributions(self) -> Optional[Dict[str, float]]:
        """
        Get feature attributions if available.

        Returns:
            Dictionary mapping feature names to attribution values,
            or None if not available.
        """
        return self.explanation_data.get("feature_attributions")

    def get_top_features(self, k: Optional[int] = 5, absolute: bool = True) -> List[tuple]:
        """
        Get the top-k most important features.

        Args:
            k: Number of top features to return. Values <= 0 return an
                empty list; None returns every feature, ranked.
            absolute: If True, rank by absolute value of attribution

        Returns:
            List of (feature_name, attribution_value) tuples sorted by
            importance (descending). Empty if no attributions are present.
        """
        # Guard explicitly: a negative slice bound would silently drop
        # items from the end instead of selecting none.
        if k is not None and k <= 0:
            return []

        attributions = self.get_attributions()
        if not attributions:
            return []

        def _rank(item):
            return abs(item[1]) if absolute else item[1]

        return sorted(attributions.items(), key=_rank, reverse=True)[:k]

    def get_feature_index(self, feature_name: str) -> Optional[int]:
        """
        Get the index of a feature by name.

        Args:
            feature_name: Name of the feature

        Returns:
            Index of the feature, or None if not found or feature_names not set
        """
        if self.feature_names is None:
            return None
        try:
            return self.feature_names.index(feature_name)
        except ValueError:
            return None

    def plot(self, plot_type: str = 'bar', **kwargs):
        """
        Visualize the explanation.

        Args:
            plot_type: Type of plot ('bar', 'waterfall', 'heatmap')
            **kwargs: Additional arguments passed to the plotting function

        Note:
            This is a placeholder for future visualization integration;
            it currently only prints a message and ignores **kwargs.
        """
        print(
            f"[plot: {plot_type}] Plotting explanation for {self.target_class} "
            f"from {self.explainer_name}."
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert explanation to a dictionary for serialization.

        Top-level containers are shallow-copied so mutating the returned
        dictionary does not alter this Explanation's state.

        Returns:
            Dictionary representation of the explanation
        """
        return {
            "explainer_name": self.explainer_name,
            "target_class": self.target_class,
            "explanation_data": dict(self.explanation_data),
            "feature_names": (
                list(self.feature_names) if self.feature_names is not None else None
            ),
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Explanation":
        """
        Create an Explanation from a dictionary.

        Args:
            data: Dictionary with explanation data; must contain
                "explainer_name", "target_class" and "explanation_data".

        Returns:
            Explanation instance

        Raises:
            KeyError: If a required key is missing.
        """
        return cls(
            explainer_name=data["explainer_name"],
            target_class=data["target_class"],
            explanation_data=data["explanation_data"],
            feature_names=data.get("feature_names"),
            metadata=data.get("metadata", {})
        )
@@ -0,0 +1,252 @@
1
+ # src/explainiverse/engine/suite.py
2
+ """
3
+ ExplanationSuite - Multi-explainer comparison and evaluation.
4
+
5
+ Provides utilities for running multiple explainers on the same instances
6
+ and comparing their outputs.
7
+ """
8
+
9
+ from typing import Dict, List, Optional, Any, Tuple
10
+ import numpy as np
11
+
12
+
13
class ExplanationSuite:
    """
    Run and compare multiple explainers on the same instances.

    This class provides a unified interface for:
    - Running multiple explainers on a single instance
    - Comparing attribution scores side-by-side
    - Suggesting the best explainer based on model/task characteristics
    - Evaluating explainers using ROAR (Remove And Retrain)

    Example:
        >>> from explainiverse import ExplanationSuite, SklearnAdapter
        >>> suite = ExplanationSuite(
        ...     model=adapter,
        ...     explainer_configs=[
        ...         ("lime", {"training_data": X_train, "feature_names": fnames, "class_names": cnames}),
        ...         ("shap", {"background_data": X_train[:50], "feature_names": fnames, "class_names": cnames}),
        ...     ]
        ... )
        >>> results = suite.run(X_test[0])
        >>> suite.compare()
    """

    # Substrings (matched case-insensitively against the model's class
    # lineage) used by suggest_best().
    _TREE_MARKERS = ("tree", "forest", "xgb", "lgbm", "catboost")
    _NEURAL_MARKERS = ("torch", "keras", "tensorflow")

    def __init__(
        self,
        model,
        explainer_configs: List[Tuple[str, Dict[str, Any]]],
        data_meta: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize the ExplanationSuite.

        Args:
            model: A model adapter (e.g., SklearnAdapter, PyTorchAdapter)
            explainer_configs: List of (explainer_name, kwargs) tuples.
                The explainer_name should match a registered explainer in
                the default_registry (e.g., "lime", "shap", "treeshap").
            data_meta: Optional metadata about the task, scope, or preference.
                Can include "task" ("classification" or "regression").
        """
        self.model = model
        self.configs = explainer_configs
        self.data_meta = data_meta or {}
        self.explanations: Dict[str, Any] = {}
        self._registry = None

    def _get_registry(self):
        """Lazy load the registry to avoid circular imports."""
        if self._registry is None:
            from explainiverse.core.registry import default_registry
            self._registry = default_registry
        return self._registry

    def run(self, instance: np.ndarray) -> Dict[str, Any]:
        """
        Run all configured explainers on a single instance.

        Results replace those of any previous run, so an explainer that
        fails on this instance does not leave behind its explanation of an
        earlier one. Failures are reported with a printed warning and the
        explainer is omitted from the result.

        Args:
            instance: Input instance to explain (1D numpy array)

        Returns:
            Dictionary mapping explainer names to Explanation objects
        """
        instance = np.asarray(instance)
        registry = self._get_registry()

        results: Dict[str, Any] = {}
        for name, params in self.configs:
            try:
                explainer = registry.create(name, model=self.model, **params)
                results[name] = explainer.explain(instance)
            except Exception as e:
                # Best-effort: one failing explainer must not abort the suite.
                print(f"[ExplanationSuite] Warning: Failed to run {name}: {e}")

        self.explanations = results
        return self.explanations

    def compare(self) -> None:
        """
        Print attribution scores side-by-side for comparison.

        Features missing from an explainer's attributions are shown as "—".
        """
        if not self.explanations:
            print("No explanations to compare. Run suite.run(instance) first.")
            return

        # `or {}` also covers explainers that store feature_attributions=None.
        attributions = {
            name: (explanation.explanation_data.get("feature_attributions") or {})
            for name, explanation in self.explanations.items()
        }
        all_keys = set()
        for attrs in attributions.values():
            all_keys.update(attrs.keys())

        header = " | ".join(f"{h!s:>15}" for h in ["Feature"] + list(attributions))
        rule = "-" * max(60, len(header))

        print("\nSide-by-Side Comparison:")
        print(rule)
        print(header)
        print(rule)

        # key=str keeps ordering for string keys and tolerates mixed key types.
        for key in sorted(all_keys, key=str):
            row = [f"{key!s:>15}"]
            for attrs in attributions.values():
                value = attrs.get(key, None)
                if value is not None:
                    row.append(f"{value:>15.4f}")
                else:
                    row.append(f"{'—':>15}")
            print(" | ".join(row))

    @staticmethod
    def _model_lineage(model) -> str:
        """Return lower-cased "module.qualname" of every class in type(model)'s MRO."""
        return " ".join(
            f"{cls.__module__ or ''}.{cls.__qualname__}"
            for cls in type(model).__mro__
        ).lower()

    def suggest_best(self) -> str:
        """
        Suggest the best explainer based on model type and task characteristics.

        Priority: regression -> "shap"; tree ensembles -> "treeshap";
        probabilistic classifiers -> "shap" (multi-class) or "lime" (binary);
        neural networks -> "integrated_gradients"; otherwise "lime".

        Returns:
            Name of the suggested explainer
        """
        task = self.data_meta.get("task", "unknown")
        model = self.model.model if hasattr(self.model, 'model') else self.model

        # 1. Regression: SHAP preferred due to consistent output
        if task == "regression":
            return "shap"

        # Inspect the whole class hierarchy so user-defined subclasses
        # (e.g. a custom torch.nn.Module) are recognized too.
        lineage = self._model_lineage(model)

        # 2. Tree-based models -> prefer TreeSHAP. Checked before the
        #    predict_proba branch, which every sklearn tree classifier would
        #    otherwise hit first, making this branch unreachable.
        if any(marker in lineage for marker in self._TREE_MARKERS):
            return "treeshap"

        # 3. Model with predict_proba -> SHAP handles probabilistic outputs well
        if hasattr(model, "predict_proba"):
            try:
                # Check output dimensions
                if hasattr(model, 'n_features_in_'):
                    test_input = np.zeros((1, model.n_features_in_))
                    output = self.model.predict(test_input)
                    if output.shape[1] > 2:
                        return "shap"  # Multi-class, SHAP more stable
                    return "lime"  # Binary, both are okay
            except Exception:
                return "shap"

        # 4. Neural networks -> prefer gradient methods
        if any(marker in lineage for marker in self._NEURAL_MARKERS):
            return "integrated_gradients"

        # 5. Default fallback
        return "lime"

    def evaluate_roar(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_test: np.ndarray,
        y_test: np.ndarray,
        top_k: int = 2,
        model_class=None,
        model_kwargs: Optional[Dict] = None
    ) -> Dict[str, float]:
        """
        Evaluate each explainer using ROAR (Remove And Retrain).

        ROAR measures explanation quality by retraining the model after
        removing the top-k important features identified by each explainer.
        A larger accuracy drop indicates more faithful explanations.

        Args:
            X_train, y_train: Training data
            X_test, y_test: Test data
            top_k: Number of features to mask
            model_class: Model constructor with .fit() and .predict()
                If None, uses the same type as self.model.model
            model_kwargs: Optional keyword args for new model instance

        Returns:
            Dict mapping explainer names to accuracy drops. An explainer
            whose evaluation raised maps to NaN (a warning is printed),
            so failures are not mistaken for a genuine zero drop.
        """
        from explainiverse.evaluation.metrics import compute_roar

        model_kwargs = model_kwargs or {}

        # Default to type(self.model.model) if not provided
        if model_class is None:
            raw_model = self.model.model if hasattr(self.model, 'model') else self.model
            model_class = type(raw_model)

        roar_scores: Dict[str, float] = {}
        for name, explanation in self.explanations.items():
            print(f"[ROAR] Evaluating explainer: {name}")
            try:
                roar_scores[name] = compute_roar(
                    model_class=model_class,
                    X_train=X_train,
                    y_train=y_train,
                    X_test=X_test,
                    y_test=y_test,
                    explanations=[explanation],
                    top_k=top_k,
                    model_kwargs=model_kwargs
                )
            except Exception as e:
                print(f"[ROAR] Failed for {name}: {e}")
                roar_scores[name] = float("nan")

        return roar_scores

    def get_explanation(self, name: str):
        """
        Get a specific explanation by explainer name.

        Args:
            name: Name of the explainer

        Returns:
            Explanation object or None if not found
        """
        return self.explanations.get(name)

    def list_explainers(self) -> List[str]:
        """
        List all configured explainer names.

        Returns:
            List of explainer names
        """
        return [name for name, _ in self.configs]

    def list_completed(self) -> List[str]:
        """
        List explainers that succeeded in the most recent run.

        Returns:
            List of explainer names with results
        """
        return list(self.explanations.keys())