explainiverse 0.7.0__tar.gz → 0.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {explainiverse-0.7.0 → explainiverse-0.7.1}/PKG-INFO +2 -2
- {explainiverse-0.7.0 → explainiverse-0.7.1}/README.md +1 -1
- {explainiverse-0.7.0 → explainiverse-0.7.1}/pyproject.toml +1 -1
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/__init__.py +1 -1
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/adapters/pytorch_adapter.py +88 -25
- explainiverse-0.7.1/src/explainiverse/core/explanation.py +179 -0
- explainiverse-0.7.1/src/explainiverse/engine/suite.py +252 -0
- explainiverse-0.7.1/src/explainiverse/evaluation/metrics.py +314 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/attribution/lime_wrapper.py +90 -7
- explainiverse-0.7.1/src/explainiverse/explainers/attribution/shap_wrapper.py +185 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/integrated_gradients.py +189 -76
- explainiverse-0.7.0/src/explainiverse/core/explanation.py +0 -24
- explainiverse-0.7.0/src/explainiverse/engine/suite.py +0 -143
- explainiverse-0.7.0/src/explainiverse/evaluation/metrics.py +0 -233
- explainiverse-0.7.0/src/explainiverse/explainers/attribution/shap_wrapper.py +0 -89
- {explainiverse-0.7.0 → explainiverse-0.7.1}/LICENSE +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/adapters/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/adapters/base_adapter.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/adapters/sklearn_adapter.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/core/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/core/explainer.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/core/registry.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/engine/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/evaluation/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/evaluation/_utils.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/evaluation/faithfulness.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/evaluation/stability.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/attribution/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/attribution/treeshap_wrapper.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/counterfactual/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/counterfactual/dice_wrapper.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/example_based/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/example_based/protodash.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/ale.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/partial_dependence.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/permutation_importance.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/global_explainers/sage.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/deeplift.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/gradcam.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/saliency.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/smoothgrad.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/gradient/tcav.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/rule_based/__init__.py +0 -0
- {explainiverse-0.7.0 → explainiverse-0.7.1}/src/explainiverse/explainers/rule_based/anchors_wrapper.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: explainiverse
|
|
3
|
-
Version: 0.7.0
|
|
3
|
+
Version: 0.7.1
|
|
4
4
|
Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
|
|
5
5
|
Home-page: https://github.com/jemsbhai/explainiverse
|
|
6
6
|
License: MIT
|
|
@@ -671,7 +671,7 @@ If you use Explainiverse in your research, please cite:
|
|
|
671
671
|
author = {Syed, Muntaser},
|
|
672
672
|
year = {2025},
|
|
673
673
|
url = {https://github.com/jemsbhai/explainiverse},
|
|
674
|
-
version = {0.7.0}
|
|
674
|
+
version = {0.7.1}
|
|
675
675
|
}
|
|
676
676
|
```
|
|
677
677
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "explainiverse"
|
|
3
|
-
version = "0.7.0"
|
|
3
|
+
version = "0.7.1"
|
|
4
4
|
description = "Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more"
|
|
5
5
|
authors = ["Muntaser Syed <jemsbhai@gmail.com>"]
|
|
6
6
|
license = "MIT"
|
|
@@ -25,7 +25,7 @@ Example:
|
|
|
25
25
|
"""
|
|
26
26
|
|
|
27
27
|
import numpy as np
|
|
28
|
-
from typing import List, Optional, Union,
|
|
28
|
+
from typing import List, Optional, Union, Tuple
|
|
29
29
|
|
|
30
30
|
from .base_adapter import BaseModelAdapter
|
|
31
31
|
|
|
@@ -57,6 +57,11 @@ class PyTorchAdapter(BaseModelAdapter):
|
|
|
57
57
|
explainability methods. Handles device management, tensor/numpy
|
|
58
58
|
conversions, and supports both classification and regression tasks.
|
|
59
59
|
|
|
60
|
+
Supports:
|
|
61
|
+
- Multi-class classification (output shape: [batch, n_classes])
|
|
62
|
+
- Binary classification (output shape: [batch, 1] or [batch])
|
|
63
|
+
- Regression (output shape: [batch, n_outputs] or [batch])
|
|
64
|
+
|
|
60
65
|
Attributes:
|
|
61
66
|
model: The PyTorch model (nn.Module)
|
|
62
67
|
task: "classification" or "regression"
|
|
@@ -150,11 +155,27 @@ class PyTorchAdapter(BaseModelAdapter):
|
|
|
150
155
|
def _apply_activation(self, output: "torch.Tensor") -> "torch.Tensor":
|
|
151
156
|
"""Apply output activation function."""
|
|
152
157
|
if self.output_activation == "softmax":
|
|
158
|
+
# Handle different output shapes
|
|
159
|
+
if output.dim() == 1 or (output.dim() == 2 and output.shape[1] == 1):
|
|
160
|
+
# Binary: apply sigmoid instead of softmax
|
|
161
|
+
return torch.sigmoid(output)
|
|
153
162
|
return torch.softmax(output, dim=-1)
|
|
154
163
|
elif self.output_activation == "sigmoid":
|
|
155
164
|
return torch.sigmoid(output)
|
|
156
165
|
return output
|
|
157
166
|
|
|
167
|
+
def _normalize_output_shape(self, output: "torch.Tensor") -> "torch.Tensor":
|
|
168
|
+
"""
|
|
169
|
+
Normalize output to consistent 2D shape (batch, outputs).
|
|
170
|
+
|
|
171
|
+
Handles:
|
|
172
|
+
- (batch,) -> (batch, 1)
|
|
173
|
+
- (batch, n) -> (batch, n)
|
|
174
|
+
"""
|
|
175
|
+
if output.dim() == 1:
|
|
176
|
+
return output.unsqueeze(-1)
|
|
177
|
+
return output
|
|
178
|
+
|
|
158
179
|
def predict(self, data: np.ndarray) -> np.ndarray:
|
|
159
180
|
"""
|
|
160
181
|
Generate predictions for input data.
|
|
@@ -183,16 +204,66 @@ class PyTorchAdapter(BaseModelAdapter):
|
|
|
183
204
|
tensor_batch = self._to_tensor(batch)
|
|
184
205
|
|
|
185
206
|
output = self.model(tensor_batch)
|
|
207
|
+
output = self._normalize_output_shape(output)
|
|
186
208
|
output = self._apply_activation(output)
|
|
187
209
|
outputs.append(self._to_numpy(output))
|
|
188
210
|
|
|
189
211
|
return np.vstack(outputs)
|
|
190
212
|
|
|
213
|
+
def _get_target_scores(
|
|
214
|
+
self,
|
|
215
|
+
output: "torch.Tensor",
|
|
216
|
+
target_class: Optional[Union[int, "torch.Tensor"]] = None
|
|
217
|
+
) -> "torch.Tensor":
|
|
218
|
+
"""
|
|
219
|
+
Extract target scores for gradient computation.
|
|
220
|
+
|
|
221
|
+
Handles both multi-class and binary classification outputs.
|
|
222
|
+
|
|
223
|
+
Args:
|
|
224
|
+
output: Raw model output (logits)
|
|
225
|
+
target_class: Target class index or tensor of indices
|
|
226
|
+
|
|
227
|
+
Returns:
|
|
228
|
+
Target scores tensor for backpropagation
|
|
229
|
+
"""
|
|
230
|
+
batch_size = output.shape[0]
|
|
231
|
+
|
|
232
|
+
# Normalize to 2D
|
|
233
|
+
if output.dim() == 1:
|
|
234
|
+
output = output.unsqueeze(-1)
|
|
235
|
+
|
|
236
|
+
n_outputs = output.shape[1]
|
|
237
|
+
|
|
238
|
+
if self.task == "classification":
|
|
239
|
+
if n_outputs == 1:
|
|
240
|
+
# Binary classification with single logit
|
|
241
|
+
# Score is the logit itself (positive class score)
|
|
242
|
+
return output.squeeze(-1)
|
|
243
|
+
else:
|
|
244
|
+
# Multi-class classification
|
|
245
|
+
if target_class is None:
|
|
246
|
+
target_class = output.argmax(dim=-1)
|
|
247
|
+
elif isinstance(target_class, int):
|
|
248
|
+
target_class = torch.tensor(
|
|
249
|
+
[target_class] * batch_size,
|
|
250
|
+
device=self.device
|
|
251
|
+
)
|
|
252
|
+
|
|
253
|
+
# Gather scores for target class
|
|
254
|
+
return output.gather(1, target_class.view(-1, 1)).squeeze(-1)
|
|
255
|
+
else:
|
|
256
|
+
# Regression: use first output or sum of outputs
|
|
257
|
+
if n_outputs == 1:
|
|
258
|
+
return output.squeeze(-1)
|
|
259
|
+
else:
|
|
260
|
+
return output.sum(dim=-1)
|
|
261
|
+
|
|
191
262
|
def predict_with_gradients(
|
|
192
263
|
self,
|
|
193
264
|
data: np.ndarray,
|
|
194
265
|
target_class: Optional[int] = None
|
|
195
|
-
) ->
|
|
266
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
196
267
|
"""
|
|
197
268
|
Generate predictions and compute gradients w.r.t. inputs.
|
|
198
269
|
|
|
@@ -203,11 +274,17 @@ class PyTorchAdapter(BaseModelAdapter):
|
|
|
203
274
|
data: Input data as numpy array.
|
|
204
275
|
target_class: Class index for gradient computation.
|
|
205
276
|
If None, uses the predicted class.
|
|
277
|
+
For binary classification with single output,
|
|
278
|
+
this is ignored (gradient w.r.t. the single logit).
|
|
206
279
|
|
|
207
280
|
Returns:
|
|
208
281
|
Tuple of (predictions, gradients) as numpy arrays.
|
|
282
|
+
- predictions: (batch, n_classes) probabilities
|
|
283
|
+
- gradients: same shape as input data
|
|
209
284
|
"""
|
|
210
285
|
data = np.array(data)
|
|
286
|
+
original_shape = data.shape
|
|
287
|
+
|
|
211
288
|
if data.ndim == 1:
|
|
212
289
|
data = data.reshape(1, -1)
|
|
213
290
|
|
|
@@ -217,20 +294,13 @@ class PyTorchAdapter(BaseModelAdapter):
|
|
|
217
294
|
|
|
218
295
|
# Forward pass
|
|
219
296
|
output = self.model(tensor_data)
|
|
220
|
-
activated_output = self._apply_activation(output)
|
|
221
297
|
|
|
222
|
-
#
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
# Select target class scores for gradient
|
|
230
|
-
target_scores = output.gather(1, target_class.view(-1, 1)).squeeze()
|
|
231
|
-
else:
|
|
232
|
-
# Regression: gradient w.r.t. output
|
|
233
|
-
target_scores = output.squeeze()
|
|
298
|
+
# Get activated output for return
|
|
299
|
+
output_normalized = self._normalize_output_shape(output)
|
|
300
|
+
activated_output = self._apply_activation(output_normalized)
|
|
301
|
+
|
|
302
|
+
# Get target scores for gradient computation
|
|
303
|
+
target_scores = self._get_target_scores(output, target_class)
|
|
234
304
|
|
|
235
305
|
# Backward pass
|
|
236
306
|
if target_scores.dim() == 0:
|
|
@@ -295,7 +365,7 @@ class PyTorchAdapter(BaseModelAdapter):
|
|
|
295
365
|
data: np.ndarray,
|
|
296
366
|
layer_name: str,
|
|
297
367
|
target_class: Optional[int] = None
|
|
298
|
-
) ->
|
|
368
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
299
369
|
"""
|
|
300
370
|
Get gradients of output w.r.t. a specific layer's activations.
|
|
301
371
|
|
|
@@ -339,15 +409,8 @@ class PyTorchAdapter(BaseModelAdapter):
|
|
|
339
409
|
|
|
340
410
|
output = self.model(tensor_data)
|
|
341
411
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
target_class = output.argmax(dim=-1)
|
|
345
|
-
elif isinstance(target_class, int):
|
|
346
|
-
target_class = torch.tensor([target_class] * data.shape[0], device=self.device)
|
|
347
|
-
|
|
348
|
-
target_scores = output.gather(1, target_class.view(-1, 1)).squeeze()
|
|
349
|
-
else:
|
|
350
|
-
target_scores = output.squeeze()
|
|
412
|
+
# Get target scores using the new method
|
|
413
|
+
target_scores = self._get_target_scores(output, target_class)
|
|
351
414
|
|
|
352
415
|
if target_scores.dim() == 0:
|
|
353
416
|
target_scores.backward()
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
# src/explainiverse/core/explanation.py
|
|
2
|
+
"""
|
|
3
|
+
Unified container for explanation results.
|
|
4
|
+
|
|
5
|
+
The Explanation class provides a standardized format for all explainer outputs,
|
|
6
|
+
enabling consistent handling across different explanation methods.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Dict, List, Optional, Any
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Explanation:
|
|
13
|
+
"""
|
|
14
|
+
Unified container for explanation results.
|
|
15
|
+
|
|
16
|
+
Attributes:
|
|
17
|
+
explainer_name: Name of the explainer that generated this explanation
|
|
18
|
+
target_class: The class/output being explained
|
|
19
|
+
explanation_data: Dictionary containing explanation details
|
|
20
|
+
(e.g., feature_attributions, heatmaps, rules)
|
|
21
|
+
feature_names: Optional list of feature names for index resolution
|
|
22
|
+
metadata: Optional additional metadata about the explanation
|
|
23
|
+
|
|
24
|
+
Example:
|
|
25
|
+
>>> explanation = Explanation(
|
|
26
|
+
... explainer_name="LIME",
|
|
27
|
+
... target_class="cat",
|
|
28
|
+
... explanation_data={"feature_attributions": {"fur": 0.8, "whiskers": 0.6}},
|
|
29
|
+
... feature_names=["fur", "whiskers", "tail", "ears"]
|
|
30
|
+
... )
|
|
31
|
+
>>> print(explanation.get_top_features(k=2))
|
|
32
|
+
[('fur', 0.8), ('whiskers', 0.6)]
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
explainer_name: str,
|
|
38
|
+
target_class: str,
|
|
39
|
+
explanation_data: Dict[str, Any],
|
|
40
|
+
feature_names: Optional[List[str]] = None,
|
|
41
|
+
metadata: Optional[Dict[str, Any]] = None
|
|
42
|
+
):
|
|
43
|
+
"""
|
|
44
|
+
Initialize an Explanation object.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
explainer_name: Name of the explainer (e.g., "LIME", "SHAP")
|
|
48
|
+
target_class: The target class or output being explained
|
|
49
|
+
explanation_data: Dictionary containing the explanation details.
|
|
50
|
+
Common keys include:
|
|
51
|
+
- "feature_attributions": Dict[str, float] mapping feature names to importance
|
|
52
|
+
- "attributions_raw": List[float] of raw attribution values
|
|
53
|
+
- "heatmap": np.ndarray for image explanations
|
|
54
|
+
- "rules": List of rule strings for rule-based explanations
|
|
55
|
+
feature_names: Optional list of feature names. If provided, enables
|
|
56
|
+
index-based lookup in evaluation metrics.
|
|
57
|
+
metadata: Optional additional metadata (e.g., computation time, parameters)
|
|
58
|
+
"""
|
|
59
|
+
self.explainer_name = explainer_name
|
|
60
|
+
self.target_class = target_class
|
|
61
|
+
self.explanation_data = explanation_data
|
|
62
|
+
self.feature_names = list(feature_names) if feature_names is not None else None
|
|
63
|
+
self.metadata = metadata or {}
|
|
64
|
+
|
|
65
|
+
def __repr__(self):
|
|
66
|
+
n_features = len(self.feature_names) if self.feature_names else "N/A"
|
|
67
|
+
return (
|
|
68
|
+
f"Explanation(explainer='{self.explainer_name}', "
|
|
69
|
+
f"target='{self.target_class}', "
|
|
70
|
+
f"keys={list(self.explanation_data.keys())}, "
|
|
71
|
+
f"n_features={n_features})"
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
def get_attributions(self) -> Optional[Dict[str, float]]:
|
|
75
|
+
"""
|
|
76
|
+
Get feature attributions if available.
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
Dictionary mapping feature names to attribution values,
|
|
80
|
+
or None if not available.
|
|
81
|
+
"""
|
|
82
|
+
return self.explanation_data.get("feature_attributions")
|
|
83
|
+
|
|
84
|
+
def get_top_features(self, k: int = 5, absolute: bool = True) -> List[tuple]:
|
|
85
|
+
"""
|
|
86
|
+
Get the top-k most important features.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
k: Number of top features to return
|
|
90
|
+
absolute: If True, rank by absolute value of attribution
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
List of (feature_name, attribution_value) tuples sorted by importance
|
|
94
|
+
"""
|
|
95
|
+
attributions = self.get_attributions()
|
|
96
|
+
if not attributions:
|
|
97
|
+
return []
|
|
98
|
+
|
|
99
|
+
if absolute:
|
|
100
|
+
sorted_items = sorted(
|
|
101
|
+
attributions.items(),
|
|
102
|
+
key=lambda x: abs(x[1]),
|
|
103
|
+
reverse=True
|
|
104
|
+
)
|
|
105
|
+
else:
|
|
106
|
+
sorted_items = sorted(
|
|
107
|
+
attributions.items(),
|
|
108
|
+
key=lambda x: x[1],
|
|
109
|
+
reverse=True
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
return sorted_items[:k]
|
|
113
|
+
|
|
114
|
+
def get_feature_index(self, feature_name: str) -> Optional[int]:
|
|
115
|
+
"""
|
|
116
|
+
Get the index of a feature by name.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
feature_name: Name of the feature
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
Index of the feature, or None if not found or feature_names not set
|
|
123
|
+
"""
|
|
124
|
+
if self.feature_names is None:
|
|
125
|
+
return None
|
|
126
|
+
try:
|
|
127
|
+
return self.feature_names.index(feature_name)
|
|
128
|
+
except ValueError:
|
|
129
|
+
return None
|
|
130
|
+
|
|
131
|
+
def plot(self, plot_type: str = 'bar', **kwargs):
|
|
132
|
+
"""
|
|
133
|
+
Visualize the explanation.
|
|
134
|
+
|
|
135
|
+
Args:
|
|
136
|
+
plot_type: Type of plot ('bar', 'waterfall', 'heatmap')
|
|
137
|
+
**kwargs: Additional arguments passed to the plotting function
|
|
138
|
+
|
|
139
|
+
Note:
|
|
140
|
+
This is a placeholder for future visualization integration.
|
|
141
|
+
"""
|
|
142
|
+
print(
|
|
143
|
+
f"[plot: {plot_type}] Plotting explanation for {self.target_class} "
|
|
144
|
+
f"from {self.explainer_name}."
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
148
|
+
"""
|
|
149
|
+
Convert explanation to a dictionary for serialization.
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
Dictionary representation of the explanation
|
|
153
|
+
"""
|
|
154
|
+
return {
|
|
155
|
+
"explainer_name": self.explainer_name,
|
|
156
|
+
"target_class": self.target_class,
|
|
157
|
+
"explanation_data": self.explanation_data,
|
|
158
|
+
"feature_names": self.feature_names,
|
|
159
|
+
"metadata": self.metadata
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
@classmethod
|
|
163
|
+
def from_dict(cls, data: Dict[str, Any]) -> "Explanation":
|
|
164
|
+
"""
|
|
165
|
+
Create an Explanation from a dictionary.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
data: Dictionary with explanation data
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
Explanation instance
|
|
172
|
+
"""
|
|
173
|
+
return cls(
|
|
174
|
+
explainer_name=data["explainer_name"],
|
|
175
|
+
target_class=data["target_class"],
|
|
176
|
+
explanation_data=data["explanation_data"],
|
|
177
|
+
feature_names=data.get("feature_names"),
|
|
178
|
+
metadata=data.get("metadata", {})
|
|
179
|
+
)
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# src/explainiverse/engine/suite.py
|
|
2
|
+
"""
|
|
3
|
+
ExplanationSuite - Multi-explainer comparison and evaluation.
|
|
4
|
+
|
|
5
|
+
Provides utilities for running multiple explainers on the same instances
|
|
6
|
+
and comparing their outputs.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Dict, List, Optional, Any, Tuple
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ExplanationSuite:
|
|
14
|
+
"""
|
|
15
|
+
Run and compare multiple explainers on the same instances.
|
|
16
|
+
|
|
17
|
+
This class provides a unified interface for:
|
|
18
|
+
- Running multiple explainers on a single instance
|
|
19
|
+
- Comparing attribution scores side-by-side
|
|
20
|
+
- Suggesting the best explainer based on model/task characteristics
|
|
21
|
+
- Evaluating explainers using ROAR (Remove And Retrain)
|
|
22
|
+
|
|
23
|
+
Example:
|
|
24
|
+
>>> from explainiverse import ExplanationSuite, SklearnAdapter
|
|
25
|
+
>>> suite = ExplanationSuite(
|
|
26
|
+
... model=adapter,
|
|
27
|
+
... explainer_configs=[
|
|
28
|
+
... ("lime", {"training_data": X_train, "feature_names": fnames, "class_names": cnames}),
|
|
29
|
+
... ("shap", {"background_data": X_train[:50], "feature_names": fnames, "class_names": cnames}),
|
|
30
|
+
... ]
|
|
31
|
+
... )
|
|
32
|
+
>>> results = suite.run(X_test[0])
|
|
33
|
+
>>> suite.compare()
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
model,
|
|
39
|
+
explainer_configs: List[Tuple[str, Dict[str, Any]]],
|
|
40
|
+
data_meta: Optional[Dict[str, Any]] = None
|
|
41
|
+
):
|
|
42
|
+
"""
|
|
43
|
+
Initialize the ExplanationSuite.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
model: A model adapter (e.g., SklearnAdapter, PyTorchAdapter)
|
|
47
|
+
explainer_configs: List of (explainer_name, kwargs) tuples.
|
|
48
|
+
The explainer_name should match a registered explainer in
|
|
49
|
+
the default_registry (e.g., "lime", "shap", "treeshap").
|
|
50
|
+
data_meta: Optional metadata about the task, scope, or preference.
|
|
51
|
+
Can include "task" ("classification" or "regression").
|
|
52
|
+
"""
|
|
53
|
+
self.model = model
|
|
54
|
+
self.configs = explainer_configs
|
|
55
|
+
self.data_meta = data_meta or {}
|
|
56
|
+
self.explanations: Dict[str, Any] = {}
|
|
57
|
+
self._registry = None
|
|
58
|
+
|
|
59
|
+
def _get_registry(self):
|
|
60
|
+
"""Lazy load the registry to avoid circular imports."""
|
|
61
|
+
if self._registry is None:
|
|
62
|
+
from explainiverse.core.registry import default_registry
|
|
63
|
+
self._registry = default_registry
|
|
64
|
+
return self._registry
|
|
65
|
+
|
|
66
|
+
def run(self, instance: np.ndarray) -> Dict[str, Any]:
|
|
67
|
+
"""
|
|
68
|
+
Run all configured explainers on a single instance.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
instance: Input instance to explain (1D numpy array)
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
Dictionary mapping explainer names to Explanation objects
|
|
75
|
+
"""
|
|
76
|
+
instance = np.asarray(instance)
|
|
77
|
+
registry = self._get_registry()
|
|
78
|
+
|
|
79
|
+
for name, params in self.configs:
|
|
80
|
+
try:
|
|
81
|
+
explainer = registry.create(name, model=self.model, **params)
|
|
82
|
+
explanation = explainer.explain(instance)
|
|
83
|
+
self.explanations[name] = explanation
|
|
84
|
+
except Exception as e:
|
|
85
|
+
print(f"[ExplanationSuite] Warning: Failed to run {name}: {e}")
|
|
86
|
+
continue
|
|
87
|
+
|
|
88
|
+
return self.explanations
|
|
89
|
+
|
|
90
|
+
def compare(self) -> None:
|
|
91
|
+
"""
|
|
92
|
+
Print attribution scores side-by-side for comparison.
|
|
93
|
+
"""
|
|
94
|
+
if not self.explanations:
|
|
95
|
+
print("No explanations to compare. Run suite.run(instance) first.")
|
|
96
|
+
return
|
|
97
|
+
|
|
98
|
+
# Collect all feature names across explanations
|
|
99
|
+
all_keys = set()
|
|
100
|
+
for explanation in self.explanations.values():
|
|
101
|
+
attrs = explanation.explanation_data.get("feature_attributions", {})
|
|
102
|
+
all_keys.update(attrs.keys())
|
|
103
|
+
|
|
104
|
+
print("\nSide-by-Side Comparison:")
|
|
105
|
+
print("-" * 60)
|
|
106
|
+
|
|
107
|
+
# Header
|
|
108
|
+
header = ["Feature"] + list(self.explanations.keys())
|
|
109
|
+
print(" | ".join(f"{h:>15}" for h in header))
|
|
110
|
+
print("-" * 60)
|
|
111
|
+
|
|
112
|
+
# Rows
|
|
113
|
+
for key in sorted(all_keys):
|
|
114
|
+
row = [f"{key:>15}"]
|
|
115
|
+
for name in self.explanations:
|
|
116
|
+
value = self.explanations[name].explanation_data.get(
|
|
117
|
+
"feature_attributions", {}
|
|
118
|
+
).get(key, None)
|
|
119
|
+
if value is not None:
|
|
120
|
+
row.append(f"{value:>15.4f}")
|
|
121
|
+
else:
|
|
122
|
+
row.append(f"{'—':>15}")
|
|
123
|
+
print(" | ".join(row))
|
|
124
|
+
|
|
125
|
+
def suggest_best(self) -> str:
|
|
126
|
+
"""
|
|
127
|
+
Suggest the best explainer based on model type and task characteristics.
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
Name of the suggested explainer
|
|
131
|
+
"""
|
|
132
|
+
task = self.data_meta.get("task", "unknown")
|
|
133
|
+
model = self.model.model if hasattr(self.model, 'model') else self.model
|
|
134
|
+
|
|
135
|
+
# 1. Regression: SHAP preferred due to consistent output
|
|
136
|
+
if task == "regression":
|
|
137
|
+
return "shap"
|
|
138
|
+
|
|
139
|
+
# 2. Model with predict_proba → SHAP handles probabilistic outputs well
|
|
140
|
+
if hasattr(model, "predict_proba"):
|
|
141
|
+
try:
|
|
142
|
+
# Check output dimensions
|
|
143
|
+
if hasattr(model, 'n_features_in_'):
|
|
144
|
+
test_input = np.zeros((1, model.n_features_in_))
|
|
145
|
+
output = self.model.predict(test_input)
|
|
146
|
+
if output.shape[1] > 2:
|
|
147
|
+
return "shap" # Multi-class, SHAP more stable
|
|
148
|
+
else:
|
|
149
|
+
return "lime" # Binary, both are okay
|
|
150
|
+
except Exception:
|
|
151
|
+
return "shap"
|
|
152
|
+
|
|
153
|
+
# 3. Tree-based models → prefer TreeSHAP
|
|
154
|
+
model_type_str = str(type(model)).lower()
|
|
155
|
+
if any(tree_type in model_type_str for tree_type in ['tree', 'forest', 'xgb', 'lgbm', 'catboost']):
|
|
156
|
+
return "treeshap"
|
|
157
|
+
|
|
158
|
+
# 4. Neural networks → prefer gradient methods
|
|
159
|
+
if 'torch' in model_type_str or 'keras' in model_type_str or 'tensorflow' in model_type_str:
|
|
160
|
+
return "integrated_gradients"
|
|
161
|
+
|
|
162
|
+
# 5. Default fallback
|
|
163
|
+
return "lime"
|
|
164
|
+
|
|
165
|
+
def evaluate_roar(
|
|
166
|
+
self,
|
|
167
|
+
X_train: np.ndarray,
|
|
168
|
+
y_train: np.ndarray,
|
|
169
|
+
X_test: np.ndarray,
|
|
170
|
+
y_test: np.ndarray,
|
|
171
|
+
top_k: int = 2,
|
|
172
|
+
model_class=None,
|
|
173
|
+
model_kwargs: Optional[Dict] = None
|
|
174
|
+
) -> Dict[str, float]:
|
|
175
|
+
"""
|
|
176
|
+
Evaluate each explainer using ROAR (Remove And Retrain).
|
|
177
|
+
|
|
178
|
+
ROAR measures explanation quality by retraining the model after
|
|
179
|
+
removing the top-k important features identified by each explainer.
|
|
180
|
+
A larger accuracy drop indicates more faithful explanations.
|
|
181
|
+
|
|
182
|
+
Args:
|
|
183
|
+
X_train, y_train: Training data
|
|
184
|
+
X_test, y_test: Test data
|
|
185
|
+
top_k: Number of features to mask
|
|
186
|
+
model_class: Model constructor with .fit() and .predict()
|
|
187
|
+
If None, uses the same type as self.model.model
|
|
188
|
+
model_kwargs: Optional keyword args for new model instance
|
|
189
|
+
|
|
190
|
+
Returns:
|
|
191
|
+
Dict mapping explainer names to accuracy drops
|
|
192
|
+
"""
|
|
193
|
+
from explainiverse.evaluation.metrics import compute_roar
|
|
194
|
+
|
|
195
|
+
model_kwargs = model_kwargs or {}
|
|
196
|
+
|
|
197
|
+
# Default to type(self.model.model) if not provided
|
|
198
|
+
if model_class is None:
|
|
199
|
+
raw_model = self.model.model if hasattr(self.model, 'model') else self.model
|
|
200
|
+
model_class = type(raw_model)
|
|
201
|
+
|
|
202
|
+
roar_scores = {}
|
|
203
|
+
|
|
204
|
+
for name, explanation in self.explanations.items():
|
|
205
|
+
print(f"[ROAR] Evaluating explainer: {name}")
|
|
206
|
+
try:
|
|
207
|
+
roar = compute_roar(
|
|
208
|
+
model_class=model_class,
|
|
209
|
+
X_train=X_train,
|
|
210
|
+
y_train=y_train,
|
|
211
|
+
X_test=X_test,
|
|
212
|
+
y_test=y_test,
|
|
213
|
+
explanations=[explanation],
|
|
214
|
+
top_k=top_k,
|
|
215
|
+
model_kwargs=model_kwargs
|
|
216
|
+
)
|
|
217
|
+
roar_scores[name] = roar
|
|
218
|
+
except Exception as e:
|
|
219
|
+
print(f"[ROAR] Failed for {name}: {e}")
|
|
220
|
+
roar_scores[name] = 0.0
|
|
221
|
+
|
|
222
|
+
return roar_scores
|
|
223
|
+
|
|
224
|
+
def get_explanation(self, name: str):
|
|
225
|
+
"""
|
|
226
|
+
Get a specific explanation by explainer name.
|
|
227
|
+
|
|
228
|
+
Args:
|
|
229
|
+
name: Name of the explainer
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
Explanation object or None if not found
|
|
233
|
+
"""
|
|
234
|
+
return self.explanations.get(name)
|
|
235
|
+
|
|
236
|
+
def list_explainers(self) -> List[str]:
|
|
237
|
+
"""
|
|
238
|
+
List all configured explainer names.
|
|
239
|
+
|
|
240
|
+
Returns:
|
|
241
|
+
List of explainer names
|
|
242
|
+
"""
|
|
243
|
+
return [name for name, _ in self.configs]
|
|
244
|
+
|
|
245
|
+
def list_completed(self) -> List[str]:
|
|
246
|
+
"""
|
|
247
|
+
List explainers that have been run successfully.
|
|
248
|
+
|
|
249
|
+
Returns:
|
|
250
|
+
List of explainer names with results
|
|
251
|
+
"""
|
|
252
|
+
return list(self.explanations.keys())
|