explainiverse 0.2.5__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/__init__.py +1 -1
- explainiverse/core/registry.py +22 -0
- explainiverse/evaluation/__init__.py +54 -2
- explainiverse/evaluation/_utils.py +325 -0
- explainiverse/evaluation/faithfulness.py +428 -0
- explainiverse/evaluation/stability.py +379 -0
- explainiverse/explainers/__init__.py +8 -0
- explainiverse/explainers/example_based/__init__.py +18 -0
- explainiverse/explainers/example_based/protodash.py +826 -0
- {explainiverse-0.2.5.dist-info → explainiverse-0.4.0.dist-info}/METADATA +2 -1
- {explainiverse-0.2.5.dist-info → explainiverse-0.4.0.dist-info}/RECORD +13 -8
- {explainiverse-0.2.5.dist-info → explainiverse-0.4.0.dist-info}/LICENSE +0 -0
- {explainiverse-0.2.5.dist-info → explainiverse-0.4.0.dist-info}/WHEEL +0 -0
explainiverse/evaluation/stability.py (new file):

@@ -0,0 +1,379 @@
+# src/explainiverse/evaluation/stability.py
+"""
+Stability evaluation metrics for explanations.
+
+Implements:
+- RIS (Relative Input Stability) - sensitivity to input perturbations
+- ROS (Relative Output Stability) - consistency with similar predictions
+- Lipschitz Estimate - local smoothness of explanations
+"""
+import numpy as np
+from typing import Union, Callable, List, Dict, Optional, Tuple
+from explainiverse.core.explanation import Explanation
+from explainiverse.core.explainer import BaseExplainer
+from explainiverse.evaluation._utils import get_prediction_value
+
+
+def _extract_attribution_vector(explanation: Explanation) -> np.ndarray:
+    """
+    Extract attribution values as a numpy array from an Explanation.
+
+    Args:
+        explanation: Explanation object with feature_attributions
+
+    Returns:
+        1D numpy array of attribution values
+    """
+    attributions = explanation.explanation_data.get("feature_attributions", {})
+    if not attributions:
+        raise ValueError("No feature attributions found in explanation.")
+
+    # Get values in consistent order
+    feature_names = getattr(explanation, 'feature_names', None)
+    if feature_names:
+        values = [attributions.get(fn, 0.0) for fn in feature_names]
+    else:
+        values = list(attributions.values())
+
+    return np.array(values, dtype=float)
+
+
+def _normalize_vector(v: np.ndarray, epsilon: float = 1e-10) -> np.ndarray:
+    """Normalize a vector to unit length."""
+    norm = np.linalg.norm(v)
+    if norm < epsilon:
+        return v
+    return v / norm
+
+
+def compute_ris(
+    explainer: BaseExplainer,
+    instance: np.ndarray,
+    n_perturbations: int = 10,
+    noise_scale: float = 0.01,
+    seed: int = None,
+) -> float:
+    """
+    Compute Relative Input Stability (RIS).
+
+    Measures how stable explanations are to small perturbations in the input.
+    Lower RIS indicates more stable explanations.
+
+    RIS = mean(||E(x) - E(x')|| / ||x - x'||) for perturbed inputs x'
+
+    Args:
+        explainer: Explainer instance with .explain() method
+        instance: Original input instance (1D array)
+        n_perturbations: Number of perturbed samples to generate
+        noise_scale: Standard deviation of Gaussian noise (relative to feature range)
+        seed: Random seed for reproducibility
+
+    Returns:
+        RIS score (lower = more stable)
+    """
+    if seed is not None:
+        np.random.seed(seed)
+
+    instance = np.asarray(instance).flatten()
+    n_features = len(instance)
+
+    # Get original explanation
+    original_exp = explainer.explain(instance)
+    original_exp.feature_names = getattr(original_exp, 'feature_names', None) or \
+        [f"feature_{i}" for i in range(n_features)]
+    original_attr = _extract_attribution_vector(original_exp)
+
+    ratios = []
+
+    for _ in range(n_perturbations):
+        # Generate perturbed input
+        noise = np.random.normal(0, noise_scale, n_features)
+        perturbed = instance + noise * np.abs(instance + 1e-10)  # Scale noise by feature magnitude
+
+        # Get perturbed explanation
+        try:
+            perturbed_exp = explainer.explain(perturbed)
+            perturbed_exp.feature_names = original_exp.feature_names
+            perturbed_attr = _extract_attribution_vector(perturbed_exp)
+        except Exception:
+            continue
+
+        # Compute ratio of changes
+        attr_diff = np.linalg.norm(original_attr - perturbed_attr)
+        input_diff = np.linalg.norm(instance - perturbed)
+
+        if input_diff > 1e-10:
+            ratios.append(attr_diff / input_diff)
+
+    if not ratios:
+        return float('inf')
+
+    return float(np.mean(ratios))
+
+
+def compute_ros(
+    explainer: BaseExplainer,
+    model,
+    instance: np.ndarray,
+    reference_instances: np.ndarray,
+    n_neighbors: int = 5,
+    prediction_threshold: float = 0.05,
+) -> float:
+    """
+    Compute Relative Output Stability (ROS).
+
+    Measures how similar explanations are for instances with similar predictions.
+    Higher ROS indicates more consistent explanations.
+
+    Args:
+        explainer: Explainer instance with .explain() method
+        model: Model adapter with predict/predict_proba method
+        instance: Query instance
+        reference_instances: Pool of reference instances to find neighbors
+        n_neighbors: Number of neighbors to compare
+        prediction_threshold: Maximum prediction difference to consider "similar"
+
+    Returns:
+        ROS score (higher = more consistent for similar predictions)
+    """
+    instance = np.asarray(instance).flatten()
+    n_features = len(instance)
+
+    # Get prediction for query instance
+    query_pred = get_prediction_value(model, instance)
+
+    # Find instances with similar predictions
+    similar_instances = []
+    for ref in reference_instances:
+        ref = np.asarray(ref).flatten()
+        ref_pred = get_prediction_value(model, ref)
+        if abs(query_pred - ref_pred) <= prediction_threshold:
+            similar_instances.append(ref)
+
+    if len(similar_instances) < 2:
+        return 1.0  # Perfect stability if no similar instances
+
+    # Limit to n_neighbors
+    similar_instances = similar_instances[:n_neighbors]
+
+    # Get explanation for query
+    query_exp = explainer.explain(instance)
+    query_exp.feature_names = getattr(query_exp, 'feature_names', None) or \
+        [f"feature_{i}" for i in range(n_features)]
+    query_attr = _normalize_vector(_extract_attribution_vector(query_exp))
+
+    # Get explanations for similar instances and compute similarity
+    similarities = []
+    for ref in similar_instances:
+        try:
+            ref_exp = explainer.explain(ref)
+            ref_exp.feature_names = query_exp.feature_names
+            ref_attr = _normalize_vector(_extract_attribution_vector(ref_exp))
+
+            # Cosine similarity
+            similarity = np.dot(query_attr, ref_attr)
+            similarities.append(similarity)
+        except Exception:
+            continue
+
+    if not similarities:
+        return 1.0
+
+    return float(np.mean(similarities))
+
+
+def compute_lipschitz_estimate(
+    explainer: BaseExplainer,
+    instance: np.ndarray,
+    n_samples: int = 20,
+    radius: float = 0.1,
+    seed: int = None,
+) -> float:
+    """
+    Estimate local Lipschitz constant of the explanation function.
+
+    The Lipschitz constant bounds how fast explanations can change:
+        ||E(x) - E(y)|| <= L * ||x - y||
+
+    Lower L indicates smoother, more stable explanations.
+
+    Args:
+        explainer: Explainer instance
+        instance: Center point for local estimate
+        n_samples: Number of sample pairs to evaluate
+        radius: Radius of ball around instance to sample from
+        seed: Random seed
+
+    Returns:
+        Estimated local Lipschitz constant (lower = smoother)
+    """
+    if seed is not None:
+        np.random.seed(seed)
+
+    instance = np.asarray(instance).flatten()
+    n_features = len(instance)
+
+    max_ratio = 0.0
+
+    for _ in range(n_samples):
+        # Generate two random points in a ball around instance
+        direction1 = np.random.randn(n_features)
+        direction1 = direction1 / np.linalg.norm(direction1)
+        r1 = np.random.uniform(0, radius)
+        point1 = instance + r1 * direction1
+
+        direction2 = np.random.randn(n_features)
+        direction2 = direction2 / np.linalg.norm(direction2)
+        r2 = np.random.uniform(0, radius)
+        point2 = instance + r2 * direction2
+
+        try:
+            exp1 = explainer.explain(point1)
+            exp1.feature_names = [f"feature_{i}" for i in range(n_features)]
+            attr1 = _extract_attribution_vector(exp1)
+
+            exp2 = explainer.explain(point2)
+            exp2.feature_names = exp1.feature_names
+            attr2 = _extract_attribution_vector(exp2)
+        except Exception:
+            continue
+
+        attr_diff = np.linalg.norm(attr1 - attr2)
+        input_diff = np.linalg.norm(point1 - point2)
+
+        if input_diff > 1e-10:
+            ratio = attr_diff / input_diff
+            max_ratio = max(max_ratio, ratio)
+
+    return float(max_ratio)
+
+
+def compute_stability_metrics(
+    explainer: BaseExplainer,
+    model,
+    instance: np.ndarray,
+    background_data: np.ndarray,
+    n_perturbations: int = 10,
+    noise_scale: float = 0.01,
+    n_neighbors: int = 5,
+    seed: int = None,
+) -> Dict[str, float]:
+    """
+    Compute comprehensive stability metrics for a single instance.
+
+    Args:
+        explainer: Explainer instance
+        model: Model adapter
+        instance: Query instance
+        background_data: Reference data for ROS computation
+        n_perturbations: Number of perturbations for RIS
+        noise_scale: Noise scale for RIS
+        n_neighbors: Number of neighbors for ROS
+        seed: Random seed
+
+    Returns:
+        Dictionary with RIS, ROS, and Lipschitz estimate
+    """
+    return {
+        "ris": compute_ris(explainer, instance, n_perturbations, noise_scale, seed),
+        "ros": compute_ros(explainer, model, instance, background_data, n_neighbors),
+        "lipschitz": compute_lipschitz_estimate(explainer, instance, seed=seed),
+    }
+
+
+def compute_batch_stability(
+    explainer: BaseExplainer,
+    model,
+    X: np.ndarray,
+    n_perturbations: int = 10,
+    noise_scale: float = 0.01,
+    max_samples: int = None,
+    seed: int = None,
+) -> Dict[str, float]:
+    """
+    Compute average stability metrics over a batch of instances.
+
+    Args:
+        explainer: Explainer instance
+        model: Model adapter
+        X: Input data (2D array)
+        n_perturbations: Number of perturbations per instance
+        noise_scale: Noise scale for perturbations
+        max_samples: Maximum number of samples to evaluate
+        seed: Random seed
+
+    Returns:
+        Dictionary with mean and std of stability metrics
+    """
+    n_samples = len(X)
+    if max_samples:
+        n_samples = min(n_samples, max_samples)
+
+    ris_scores = []
+    ros_scores = []
+
+    for i in range(n_samples):
+        instance = X[i]
+
+        try:
+            ris = compute_ris(explainer, instance, n_perturbations, noise_scale, seed)
+            if not np.isinf(ris):
+                ris_scores.append(ris)
+
+            ros = compute_ros(explainer, model, instance, X, n_neighbors=5)
+            ros_scores.append(ros)
+        except Exception:
+            continue
+
+    results = {"n_samples": len(ris_scores)}
+
+    if ris_scores:
+        results["mean_ris"] = np.mean(ris_scores)
+        results["std_ris"] = np.std(ris_scores)
+    else:
+        results["mean_ris"] = float('inf')
+        results["std_ris"] = 0.0
+
+    if ros_scores:
+        results["mean_ros"] = np.mean(ros_scores)
+        results["std_ros"] = np.std(ros_scores)
+    else:
+        results["mean_ros"] = 0.0
+        results["std_ros"] = 0.0
+
+    return results
+
+
+def compare_explainer_stability(
+    explainers: Dict[str, BaseExplainer],
+    model,
+    X: np.ndarray,
+    n_perturbations: int = 5,
+    noise_scale: float = 0.01,
+    max_samples: int = 20,
+    seed: int = None,
+) -> Dict[str, Dict[str, float]]:
+    """
+    Compare stability metrics across multiple explainers.
+
+    Args:
+        explainers: Dict mapping explainer names to explainer instances
+        model: Model adapter
+        X: Input data
+        n_perturbations: Number of perturbations per instance
+        noise_scale: Noise scale
+        max_samples: Max samples to evaluate per explainer
+        seed: Random seed
+
+    Returns:
+        Dict mapping explainer names to their stability metrics
+    """
+    results = {}
+
+    for name, explainer in explainers.items():
+        metrics = compute_batch_stability(
+            explainer, model, X, n_perturbations, noise_scale, max_samples, seed
+        )
+        results[name] = metrics
+
+    return results
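A minimal usage sketch (not part of the package) to make the new metrics concrete. The GradientLikeExplainer stand-in below is hypothetical: it only mimics the attributes the metrics read (an .explain() method returning an object with .explanation_data and .feature_names). In real use you would pass one of the package's explainers (e.g. LimeExplainer) and, for compute_ros or compute_stability_metrics, a model adapter as well; the import path is assumed from the file location shown above.

```python
# Hypothetical sketch: a stand-in explainer whose attributions are 2 * x,
# so explanations change linearly (and therefore smoothly) with the input.
import numpy as np
from types import SimpleNamespace

from explainiverse.evaluation.stability import compute_lipschitz_estimate, compute_ris


class GradientLikeExplainer:
    def explain(self, instance):
        x = np.asarray(instance, dtype=float).flatten()
        attributions = {f"feature_{i}": 2.0 * v for i, v in enumerate(x)}
        # Mimic only the attributes the stability metrics access.
        return SimpleNamespace(
            explanation_data={"feature_attributions": attributions},
            feature_names=list(attributions.keys()),
        )


explainer = GradientLikeExplainer()
x = np.array([0.5, -1.2, 3.0])

ris = compute_ris(explainer, x, n_perturbations=10, noise_scale=0.01, seed=0)
lipschitz = compute_lipschitz_estimate(explainer, x, n_samples=20, radius=0.1, seed=0)

# Both values should come out near 2.0: the toy attribution map is linear with
# slope 2, so ||E(x) - E(x')|| / ||x - x'|| equals 2 for every sampled pair.
print(f"RIS: {ris:.3f}, Lipschitz estimate: {lipschitz:.3f}")
```

Note that compute_ris averages this ratio over random perturbations while compute_lipschitz_estimate keeps the maximum over sampled pairs, so on the same explainer the Lipschitz estimate is the more pessimistic of the two.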
explainiverse/explainers/__init__.py:

@@ -9,12 +9,17 @@ Local Explainers (instance-level):
 - Anchors: High-precision rule-based explanations
 - Counterfactual: Diverse counterfactual explanations
 - Integrated Gradients: Gradient-based attributions for neural networks
+- DeepLIFT: Reference-based attributions for neural networks
+- DeepSHAP: DeepLIFT combined with SHAP for neural networks
 
 Global Explainers (model-level):
 - Permutation Importance: Feature importance via permutation
 - Partial Dependence: Marginal feature effects (PDP)
 - ALE: Accumulated Local Effects (unbiased for correlated features)
 - SAGE: Shapley Additive Global importancE
+
+Example-Based Explainers:
+- ProtoDash: Prototype selection with importance weights
 """
 
 from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
@@ -29,6 +34,7 @@ from explainiverse.explainers.global_explainers.sage import SAGEExplainer
 from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
 from explainiverse.explainers.gradient.gradcam import GradCAMExplainer
 from explainiverse.explainers.gradient.deeplift import DeepLIFTExplainer, DeepLIFTShapExplainer
+from explainiverse.explainers.example_based.protodash import ProtoDashExplainer
 
 __all__ = [
     # Local explainers
@@ -46,4 +52,6 @@ __all__ = [
     "PartialDependenceExplainer",
     "ALEExplainer",
     "SAGEExplainer",
+    # Example-based explainers
+    "ProtoDashExplainer",
 ]
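With this change the example-based explainer is exported from the same module as the attribution, gradient, and global explainers. A hypothetical import sketch; the constructor arguments for ProtoDashExplainer are defined in protodash.py, which is not shown in this excerpt:

```python
# Hypothetical sketch: these names are importable from explainiverse.explainers
# per the diff above; how ProtoDashExplainer is configured is not shown here.
from explainiverse.explainers import LimeExplainer, ProtoDashExplainer, SAGEExplainer

print(ProtoDashExplainer.__module__)  # explainiverse.explainers.example_based.protodash
```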
explainiverse/explainers/example_based/__init__.py (new file):

@@ -0,0 +1,18 @@
+# src/explainiverse/explainers/example_based/__init__.py
+"""
+Example-based explanation methods.
+
+These methods explain models by identifying representative examples
+from the training data, rather than computing feature attributions.
+
+Methods:
+- ProtoDash: Select prototypical examples with importance weights
+- (Future) Influence Functions: Identify training examples that most affect predictions
+- (Future) MMD-Critic: Find prototypes and criticisms
+"""
+
+from explainiverse.explainers.example_based.protodash import ProtoDashExplainer
+
+__all__ = [
+    "ProtoDashExplainer",
+]