explainiverse 0.2.5__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,379 @@
+ # src/explainiverse/evaluation/stability.py
+ """
+ Stability evaluation metrics for explanations.
+
+ Implements:
+ - RIS (Relative Input Stability) - sensitivity to input perturbations
+ - ROS (Relative Output Stability) - consistency with similar predictions
+ - Lipschitz Estimate - local smoothness of explanations
+ """
+ import numpy as np
+ from typing import Union, Callable, List, Dict, Optional, Tuple
+ from explainiverse.core.explanation import Explanation
+ from explainiverse.core.explainer import BaseExplainer
+ from explainiverse.evaluation._utils import get_prediction_value
+
+
+ def _extract_attribution_vector(explanation: Explanation) -> np.ndarray:
+     """
+     Extract attribution values as a numpy array from an Explanation.
+
+     Args:
+         explanation: Explanation object with feature_attributions
+
+     Returns:
+         1D numpy array of attribution values
+     """
+     attributions = explanation.explanation_data.get("feature_attributions", {})
+     if not attributions:
+         raise ValueError("No feature attributions found in explanation.")
+
+     # Get values in consistent order
+     feature_names = getattr(explanation, 'feature_names', None)
+     if feature_names:
+         values = [attributions.get(fn, 0.0) for fn in feature_names]
+     else:
+         values = list(attributions.values())
+
+     return np.array(values, dtype=float)
+
+
+ def _normalize_vector(v: np.ndarray, epsilon: float = 1e-10) -> np.ndarray:
+     """Normalize a vector to unit length."""
+     norm = np.linalg.norm(v)
+     if norm < epsilon:
+         return v
+     return v / norm
+
+
+ def compute_ris(
+     explainer: BaseExplainer,
+     instance: np.ndarray,
+     n_perturbations: int = 10,
+     noise_scale: float = 0.01,
+     seed: int = None,
+ ) -> float:
+     """
+     Compute Relative Input Stability (RIS).
+
+     Measures how stable explanations are to small perturbations in the input.
+     Lower RIS indicates more stable explanations.
+
+     RIS = mean(||E(x) - E(x')|| / ||x - x'||) for perturbed inputs x'
+
+     Args:
+         explainer: Explainer instance with .explain() method
+         instance: Original input instance (1D array)
+         n_perturbations: Number of perturbed samples to generate
+         noise_scale: Standard deviation of Gaussian noise (scaled by each feature's magnitude)
+         seed: Random seed for reproducibility
+
+     Returns:
+         RIS score (lower = more stable)
+     """
+     if seed is not None:
+         np.random.seed(seed)
+
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     # Get original explanation
+     original_exp = explainer.explain(instance)
+     original_exp.feature_names = getattr(original_exp, 'feature_names', None) or \
+         [f"feature_{i}" for i in range(n_features)]
+     original_attr = _extract_attribution_vector(original_exp)
+
+     ratios = []
+
+     for _ in range(n_perturbations):
+         # Generate perturbed input
+         noise = np.random.normal(0, noise_scale, n_features)
+         perturbed = instance + noise * np.abs(instance + 1e-10)  # Scale noise by feature magnitude
+
+         # Get perturbed explanation
+         try:
+             perturbed_exp = explainer.explain(perturbed)
+             perturbed_exp.feature_names = original_exp.feature_names
+             perturbed_attr = _extract_attribution_vector(perturbed_exp)
+         except Exception:
+             continue
+
+         # Compute ratio of changes
+         attr_diff = np.linalg.norm(original_attr - perturbed_attr)
+         input_diff = np.linalg.norm(instance - perturbed)
+
+         if input_diff > 1e-10:
+             ratios.append(attr_diff / input_diff)
+
+     if not ratios:
+         return float('inf')
+
+     return float(np.mean(ratios))
+
+
+ def compute_ros(
+     explainer: BaseExplainer,
+     model,
+     instance: np.ndarray,
+     reference_instances: np.ndarray,
+     n_neighbors: int = 5,
+     prediction_threshold: float = 0.05,
+ ) -> float:
+     """
+     Compute Relative Output Stability (ROS).
+
+     Measures how similar explanations are for instances with similar predictions.
+     Higher ROS indicates more consistent explanations.
+
+     Args:
+         explainer: Explainer instance with .explain() method
+         model: Model adapter with predict/predict_proba method
+         instance: Query instance
+         reference_instances: Pool of reference instances to find neighbors
+         n_neighbors: Number of neighbors to compare
+         prediction_threshold: Maximum prediction difference to consider "similar"
+
+     Returns:
+         ROS score (higher = more consistent for similar predictions)
+     """
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     # Get prediction for query instance
+     query_pred = get_prediction_value(model, instance)
+
+     # Find instances with similar predictions
+     similar_instances = []
+     for ref in reference_instances:
+         ref = np.asarray(ref).flatten()
+         ref_pred = get_prediction_value(model, ref)
+         if abs(query_pred - ref_pred) <= prediction_threshold:
+             similar_instances.append(ref)
+
+     if len(similar_instances) < 2:
+         return 1.0  # Perfect stability if no similar instances
+
+     # Limit to n_neighbors
+     similar_instances = similar_instances[:n_neighbors]
+
+     # Get explanation for query
+     query_exp = explainer.explain(instance)
+     query_exp.feature_names = getattr(query_exp, 'feature_names', None) or \
+         [f"feature_{i}" for i in range(n_features)]
+     query_attr = _normalize_vector(_extract_attribution_vector(query_exp))
+
+     # Get explanations for similar instances and compute similarity
+     similarities = []
+     for ref in similar_instances:
+         try:
+             ref_exp = explainer.explain(ref)
+             ref_exp.feature_names = query_exp.feature_names
+             ref_attr = _normalize_vector(_extract_attribution_vector(ref_exp))
+
+             # Cosine similarity
+             similarity = np.dot(query_attr, ref_attr)
+             similarities.append(similarity)
+         except Exception:
+             continue
+
+     if not similarities:
+         return 1.0
+
+     return float(np.mean(similarities))
+
+
+ def compute_lipschitz_estimate(
+     explainer: BaseExplainer,
+     instance: np.ndarray,
+     n_samples: int = 20,
+     radius: float = 0.1,
+     seed: int = None,
+ ) -> float:
+     """
+     Estimate local Lipschitz constant of the explanation function.
+
+     The Lipschitz constant bounds how fast explanations can change:
+         ||E(x) - E(y)|| <= L * ||x - y||
+
+     Lower L indicates smoother, more stable explanations.
+
+     Args:
+         explainer: Explainer instance
+         instance: Center point for local estimate
+         n_samples: Number of sample pairs to evaluate
+         radius: Radius of ball around instance to sample from
+         seed: Random seed
+
+     Returns:
+         Estimated local Lipschitz constant (lower = smoother)
+     """
+     if seed is not None:
+         np.random.seed(seed)
+
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     max_ratio = 0.0
+
+     for _ in range(n_samples):
+         # Generate two random points in a ball around instance
+         direction1 = np.random.randn(n_features)
+         direction1 = direction1 / np.linalg.norm(direction1)
+         r1 = np.random.uniform(0, radius)
+         point1 = instance + r1 * direction1
+
+         direction2 = np.random.randn(n_features)
+         direction2 = direction2 / np.linalg.norm(direction2)
+         r2 = np.random.uniform(0, radius)
+         point2 = instance + r2 * direction2
+
+         try:
+             exp1 = explainer.explain(point1)
+             exp1.feature_names = [f"feature_{i}" for i in range(n_features)]
+             attr1 = _extract_attribution_vector(exp1)
+
+             exp2 = explainer.explain(point2)
+             exp2.feature_names = exp1.feature_names
+             attr2 = _extract_attribution_vector(exp2)
+         except Exception:
+             continue
+
+         attr_diff = np.linalg.norm(attr1 - attr2)
+         input_diff = np.linalg.norm(point1 - point2)
+
+         if input_diff > 1e-10:
+             ratio = attr_diff / input_diff
+             max_ratio = max(max_ratio, ratio)
+
+     return float(max_ratio)
+
+
+ def compute_stability_metrics(
+     explainer: BaseExplainer,
+     model,
+     instance: np.ndarray,
+     background_data: np.ndarray,
+     n_perturbations: int = 10,
+     noise_scale: float = 0.01,
+     n_neighbors: int = 5,
+     seed: int = None,
+ ) -> Dict[str, float]:
+     """
+     Compute comprehensive stability metrics for a single instance.
+
+     Args:
+         explainer: Explainer instance
+         model: Model adapter
+         instance: Query instance
+         background_data: Reference data for ROS computation
+         n_perturbations: Number of perturbations for RIS
+         noise_scale: Noise scale for RIS
+         n_neighbors: Number of neighbors for ROS
+         seed: Random seed
+
+     Returns:
+         Dictionary with RIS, ROS, and Lipschitz estimate
+     """
+     return {
+         "ris": compute_ris(explainer, instance, n_perturbations, noise_scale, seed),
+         "ros": compute_ros(explainer, model, instance, background_data, n_neighbors),
+         "lipschitz": compute_lipschitz_estimate(explainer, instance, seed=seed),
+     }
+
+
+ def compute_batch_stability(
+     explainer: BaseExplainer,
+     model,
+     X: np.ndarray,
+     n_perturbations: int = 10,
+     noise_scale: float = 0.01,
+     max_samples: int = None,
+     seed: int = None,
+ ) -> Dict[str, float]:
+     """
+     Compute average stability metrics over a batch of instances.
+
+     Args:
+         explainer: Explainer instance
+         model: Model adapter
+         X: Input data (2D array)
+         n_perturbations: Number of perturbations per instance
+         noise_scale: Noise scale for perturbations
+         max_samples: Maximum number of samples to evaluate
+         seed: Random seed
+
+     Returns:
+         Dictionary with mean and std of stability metrics
+     """
+     n_samples = len(X)
+     if max_samples:
+         n_samples = min(n_samples, max_samples)
+
+     ris_scores = []
+     ros_scores = []
+
+     for i in range(n_samples):
+         instance = X[i]
+
+         try:
+             ris = compute_ris(explainer, instance, n_perturbations, noise_scale, seed)
+             if not np.isinf(ris):
+                 ris_scores.append(ris)
+
+             ros = compute_ros(explainer, model, instance, X, n_neighbors=5)
+             ros_scores.append(ros)
+         except Exception:
+             continue
+
+     results = {"n_samples": len(ris_scores)}
+
+     if ris_scores:
+         results["mean_ris"] = np.mean(ris_scores)
+         results["std_ris"] = np.std(ris_scores)
+     else:
+         results["mean_ris"] = float('inf')
+         results["std_ris"] = 0.0
+
+     if ros_scores:
+         results["mean_ros"] = np.mean(ros_scores)
+         results["std_ros"] = np.std(ros_scores)
+     else:
+         results["mean_ros"] = 0.0
+         results["std_ros"] = 0.0
+
+     return results
+
+
+ def compare_explainer_stability(
+     explainers: Dict[str, BaseExplainer],
+     model,
+     X: np.ndarray,
+     n_perturbations: int = 5,
+     noise_scale: float = 0.01,
+     max_samples: int = 20,
+     seed: int = None,
+ ) -> Dict[str, Dict[str, float]]:
+     """
+     Compare stability metrics across multiple explainers.
+
+     Args:
+         explainers: Dict mapping explainer names to explainer instances
+         model: Model adapter
+         X: Input data
+         n_perturbations: Number of perturbations per instance
+         noise_scale: Noise scale
+         max_samples: Max samples to evaluate per explainer
+         seed: Random seed
+
+     Returns:
+         Dict mapping explainer names to their stability metrics
+     """
+     results = {}
+
+     for name, explainer in explainers.items():
+         metrics = compute_batch_stability(
+             explainer, model, X, n_perturbations, noise_scale, max_samples, seed
+         )
+         results[name] = metrics
+
+     return results
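
The module above gives both per-instance entry points (compute_ris, compute_ros, compute_lipschitz_estimate, compute_stability_metrics) and batch helpers. A minimal usage sketch follows, under stated assumptions: `explainer`, `second_explainer`, and `model` are hypothetical stand-ins for already-constructed BaseExplainer subclasses whose explain() returns an Explanation carrying feature_attributions, and for a model adapter compatible with get_prediction_value; none of these objects are defined in this diff.

import numpy as np
from explainiverse.evaluation.stability import (
    compute_stability_metrics,
    compare_explainer_stability,
)

X = np.random.rand(100, 5)   # illustrative background / reference data
instance = X[0]              # single query instance

# Per-instance report: {"ris": ..., "ros": ..., "lipschitz": ...}
metrics = compute_stability_metrics(
    explainer, model, instance, background_data=X,
    n_perturbations=10, noise_scale=0.01, n_neighbors=5, seed=42,
)

# Side-by-side comparison over at most 20 samples (mean/std of RIS and ROS)
comparison = compare_explainer_stability(
    {"lime": explainer, "shap": second_explainer},
    model, X, n_perturbations=5, max_samples=20, seed=42,
)

Per the docstrings above, lower RIS and Lipschitz values and higher ROS values indicate the more stable explainer.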
@@ -9,12 +9,17 @@ Local Explainers (instance-level):
  - Anchors: High-precision rule-based explanations
  - Counterfactual: Diverse counterfactual explanations
  - Integrated Gradients: Gradient-based attributions for neural networks
+ - DeepLIFT: Reference-based attributions for neural networks
+ - DeepSHAP: DeepLIFT combined with SHAP for neural networks
 
  Global Explainers (model-level):
  - Permutation Importance: Feature importance via permutation
  - Partial Dependence: Marginal feature effects (PDP)
  - ALE: Accumulated Local Effects (unbiased for correlated features)
  - SAGE: Shapley Additive Global importancE
+
+ Example-Based Explainers:
+ - ProtoDash: Prototype selection with importance weights
  """
 
  from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
@@ -29,6 +34,7 @@ from explainiverse.explainers.global_explainers.sage import SAGEExplainer
  from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
  from explainiverse.explainers.gradient.gradcam import GradCAMExplainer
  from explainiverse.explainers.gradient.deeplift import DeepLIFTExplainer, DeepLIFTShapExplainer
+ from explainiverse.explainers.example_based.protodash import ProtoDashExplainer
 
  __all__ = [
      # Local explainers
@@ -46,4 +52,6 @@ __all__ = [
      "PartialDependenceExplainer",
      "ALEExplainer",
      "SAGEExplainer",
+     # Example-based explainers
+     "ProtoDashExplainer",
  ]
@@ -0,0 +1,18 @@
+ # src/explainiverse/explainers/example_based/__init__.py
+ """
+ Example-based explanation methods.
+
+ These methods explain models by identifying representative examples
+ from the training data, rather than computing feature attributions.
+
+ Methods:
+ - ProtoDash: Select prototypical examples with importance weights
+ - (Future) Influence Functions: Identify training examples that most affect predictions
+ - (Future) MMD-Critic: Find prototypes and criticisms
+ """
+
+ from explainiverse.explainers.example_based.protodash import ProtoDashExplainer
+
+ __all__ = [
+     "ProtoDashExplainer",
+ ]
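
Taken together with the explainers/__init__.py hunks above, this new subpackage makes ProtoDashExplainer importable from both the top-level explainers namespace and the example_based subpackage. A minimal import sketch; only the import paths follow from this diff, and the ProtoDash constructor and explain() arguments are deliberately not shown because they are not part of these hunks.

# Paths implied by the __init__.py changes in this diff.
from explainiverse.explainers import ProtoDashExplainer, DeepLIFTExplainer, DeepLIFTShapExplainer
from explainiverse.explainers.example_based import ProtoDashExplainer as ProtoDash

Per the updated docstrings, ProtoDash explains a model by selecting prototypical training examples with importance weights rather than by computing feature attributions.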