explainiverse 0.2.5__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {explainiverse-0.2.5 → explainiverse-0.3.0}/PKG-INFO +2 -1
  2. {explainiverse-0.2.5 → explainiverse-0.3.0}/pyproject.toml +2 -1
  3. explainiverse-0.3.0/src/explainiverse/evaluation/__init__.py +60 -0
  4. explainiverse-0.3.0/src/explainiverse/evaluation/_utils.py +325 -0
  5. explainiverse-0.3.0/src/explainiverse/evaluation/faithfulness.py +428 -0
  6. explainiverse-0.3.0/src/explainiverse/evaluation/stability.py +379 -0
  7. explainiverse-0.2.5/src/explainiverse/evaluation/__init__.py +0 -8
  8. {explainiverse-0.2.5 → explainiverse-0.3.0}/LICENSE +0 -0
  9. {explainiverse-0.2.5 → explainiverse-0.3.0}/README.md +0 -0
  10. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/__init__.py +0 -0
  11. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/adapters/__init__.py +0 -0
  12. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/adapters/base_adapter.py +0 -0
  13. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/adapters/pytorch_adapter.py +0 -0
  14. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/adapters/sklearn_adapter.py +0 -0
  15. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/core/__init__.py +0 -0
  16. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/core/explainer.py +0 -0
  17. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/core/explanation.py +0 -0
  18. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/core/registry.py +0 -0
  19. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/engine/__init__.py +0 -0
  20. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/engine/suite.py +0 -0
  21. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/evaluation/metrics.py +0 -0
  22. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/__init__.py +0 -0
  23. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/attribution/__init__.py +0 -0
  24. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/attribution/lime_wrapper.py +0 -0
  25. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/attribution/shap_wrapper.py +0 -0
  26. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/attribution/treeshap_wrapper.py +0 -0
  27. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/counterfactual/__init__.py +0 -0
  28. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/counterfactual/dice_wrapper.py +0 -0
  29. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/global_explainers/__init__.py +0 -0
  30. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/global_explainers/ale.py +0 -0
  31. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/global_explainers/partial_dependence.py +0 -0
  32. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/global_explainers/permutation_importance.py +0 -0
  33. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/global_explainers/sage.py +0 -0
  34. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/gradient/__init__.py +0 -0
  35. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/gradient/deeplift.py +0 -0
  36. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/gradient/gradcam.py +0 -0
  37. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/gradient/integrated_gradients.py +0 -0
  38. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/rule_based/__init__.py +0 -0
  39. {explainiverse-0.2.5 → explainiverse-0.3.0}/src/explainiverse/explainers/rule_based/anchors_wrapper.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: explainiverse
- Version: 0.2.5
+ Version: 0.3.0
  Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
  Home-page: https://github.com/jemsbhai/explainiverse
  License: MIT
@@ -20,6 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Provides-Extra: torch
  Requires-Dist: lime (>=0.2.0.1,<0.3.0.0)
  Requires-Dist: numpy (>=1.24,<2.0)
+ Requires-Dist: pandas (>=1.5,<3.0)
  Requires-Dist: scikit-learn (>=1.1,<1.6)
  Requires-Dist: scipy (>=1.10,<2.0)
  Requires-Dist: shap (>=0.48.0,<0.49.0)
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "explainiverse"
- version = "0.2.5"
+ version = "0.3.0"
  description = "Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more"
  authors = ["Muntaser Syed <jemsbhai@gmail.com>"]
  license = "MIT"
@@ -27,6 +27,7 @@ numpy = ">=1.24,<2.0"
  lime = "^0.2.0.1"
  scikit-learn = ">=1.1,<1.6"
  shap = "^0.48.0"
+ pandas = ">=1.5,<3.0"
  scipy = ">=1.10,<2.0"
  xgboost = ">=1.7,<3.0"
  torch = { version = ">=2.0", optional = true }
@@ -0,0 +1,60 @@
+ # src/explainiverse/evaluation/__init__.py
+ """
+ Evaluation metrics for explanation quality.
+
+ Includes:
+ - Faithfulness metrics (PGI, PGU, Comprehensiveness, Sufficiency)
+ - Stability metrics (RIS, ROS, Lipschitz)
+ - Perturbation metrics (AOPC, ROAR)
+ """
+
+ from explainiverse.evaluation.metrics import (
+     compute_aopc,
+     compute_batch_aopc,
+     compute_roar,
+     compute_roar_curve,
+ )
+
+ from explainiverse.evaluation.faithfulness import (
+     compute_pgi,
+     compute_pgu,
+     compute_faithfulness_score,
+     compute_comprehensiveness,
+     compute_sufficiency,
+     compute_faithfulness_correlation,
+     compare_explainer_faithfulness,
+     compute_batch_faithfulness,
+ )
+
+ from explainiverse.evaluation.stability import (
+     compute_ris,
+     compute_ros,
+     compute_lipschitz_estimate,
+     compute_stability_metrics,
+     compute_batch_stability,
+     compare_explainer_stability,
+ )
+
+ __all__ = [
+     # Perturbation metrics (existing)
+     "compute_aopc",
+     "compute_batch_aopc",
+     "compute_roar",
+     "compute_roar_curve",
+     # Faithfulness metrics (new)
+     "compute_pgi",
+     "compute_pgu",
+     "compute_faithfulness_score",
+     "compute_comprehensiveness",
+     "compute_sufficiency",
+     "compute_faithfulness_correlation",
+     "compare_explainer_faithfulness",
+     "compute_batch_faithfulness",
+     # Stability metrics (new)
+     "compute_ris",
+     "compute_ros",
+     "compute_lipschitz_estimate",
+     "compute_stability_metrics",
+     "compute_batch_stability",
+     "compare_explainer_stability",
+ ]
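
Usage sketch (not part of the released diff): the re-exports above make the new metrics reachable from the package namespace. A minimal sketch, assuming `model` is a trained classifier wrapped in an explainiverse adapter and `exp` is a single Explanation object produced by one of the bundled explainers; both names are placeholders, not symbols from this release.

    import numpy as np
    from explainiverse.evaluation import compute_faithfulness_score

    X = np.random.rand(50, 4)  # hypothetical background data
    scores = compute_faithfulness_score(
        model, X[0], exp, k=0.2, baseline="mean", background_data=X
    )
    # scores -> {"pgi": ..., "pgu": ..., "faithfulness_ratio": ..., "faithfulness_diff": ...}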
@@ -0,0 +1,325 @@
+ # src/explainiverse/evaluation/_utils.py
+ """
+ Shared utility functions for evaluation metrics.
+ """
+ import numpy as np
+ import re
+ from typing import Union, Callable, List, Tuple
+ from explainiverse.core.explanation import Explanation
+
+
+ def _extract_base_feature_name(feature_str: str) -> str:
+     """
+     Extract the base feature name from LIME-style feature strings.
+
+     LIME returns strings like "petal width (cm) <= 0.80" or "feature_2 > 3.5".
+     This extracts just the feature name part.
+
+     Args:
+         feature_str: Feature string possibly with conditions
+
+     Returns:
+         Base feature name
+     """
+     # Remove comparison operators and values
+     # Pattern matches: name <= value, name < value, name >= value, name > value, name = value
+     patterns = [
+         r'^(.+?)\s*<=\s*[\d\.\-]+$',
+         r'^(.+?)\s*>=\s*[\d\.\-]+$',
+         r'^(.+?)\s*<\s*[\d\.\-]+$',
+         r'^(.+?)\s*>\s*[\d\.\-]+$',
+         r'^(.+?)\s*=\s*[\d\.\-]+$',
+     ]
+
+     for pattern in patterns:
+         match = re.match(pattern, feature_str.strip())
+         if match:
+             return match.group(1).strip()
+
+     # No match found, return as-is
+     return feature_str.strip()
+
+
+ def _match_feature_to_index(
+     feature_key: str,
+     feature_names: List[str]
+ ) -> int:
+     """
+     Match a feature key (possibly with LIME conditions) to its index.
+
+     Args:
+         feature_key: Feature name from explanation (may include conditions)
+         feature_names: List of original feature names
+
+     Returns:
+         Index of the matching feature, or -1 if not found
+     """
+     # Try exact match first
+     if feature_key in feature_names:
+         return feature_names.index(feature_key)
+
+     # Try extracting base name
+     base_name = _extract_base_feature_name(feature_key)
+     if base_name in feature_names:
+         return feature_names.index(base_name)
+
+     # Try partial matching (feature name is contained in key)
+     for i, fname in enumerate(feature_names):
+         if fname in feature_key:
+             return i
+
+     # Try index extraction from patterns like "feature_2" or "f2" or "feat_2"
+     patterns = [
+         r'feature[_\s]*(\d+)',
+         r'feat[_\s]*(\d+)',
+         r'^f(\d+)$',
+         r'^x(\d+)$',
+     ]
+     for pattern in patterns:
+         match = re.search(pattern, feature_key, re.IGNORECASE)
+         if match:
+             idx = int(match.group(1))
+             if 0 <= idx < len(feature_names):
+                 return idx
+
+     return -1
+
+
+ def get_sorted_feature_indices(
+     explanation: Explanation,
+     descending: bool = True
+ ) -> List[int]:
+     """
+     Extract feature indices sorted by absolute attribution value.
+
+     Handles various feature naming conventions:
+     - Clean names: "sepal length", "feature_0"
+     - LIME-style: "sepal length <= 5.0", "feature_0 > 2.3"
+     - Indexed: "f0", "x1", "feat_2"
+
+     Args:
+         explanation: Explanation object with feature_attributions
+         descending: If True, sort from most to least important
+
+     Returns:
+         List of feature indices sorted by importance
+     """
+     attributions = explanation.explanation_data.get("feature_attributions", {})
+
+     if not attributions:
+         raise ValueError("No feature attributions found in explanation.")
+
+     # Sort features by absolute importance
+     sorted_features = sorted(
+         attributions.items(),
+         key=lambda x: abs(x[1]),
+         reverse=descending
+     )
+
+     # Map feature names to indices
+     feature_indices = []
+     feature_names = getattr(explanation, 'feature_names', None)
+
+     for i, (fname, _) in enumerate(sorted_features):
+         if feature_names is not None:
+             idx = _match_feature_to_index(fname, feature_names)
+             if idx >= 0:
+                 feature_indices.append(idx)
+             else:
+                 # Fallback: use position in sorted list
+                 feature_indices.append(i % len(feature_names))
+         else:
+             # No feature_names available - try to extract index from name
+             patterns = [
+                 r'feature[_\s]*(\d+)',
+                 r'feat[_\s]*(\d+)',
+                 r'^f(\d+)',
+                 r'^x(\d+)',
+             ]
+             found = False
+             for pattern in patterns:
+                 match = re.search(pattern, fname, re.IGNORECASE)
+                 if match:
+                     feature_indices.append(int(match.group(1)))
+                     found = True
+                     break
+             if not found:
+                 feature_indices.append(i)
+
+     return feature_indices
+
+
+ def compute_baseline_values(
+     baseline: Union[str, float, np.ndarray, Callable],
+     background_data: np.ndarray = None,
+     n_features: int = None
+ ) -> np.ndarray:
+     """
+     Compute per-feature baseline values for perturbation.
+
+     Args:
+         baseline: Baseline specification - one of:
+             - "mean": Use mean of background_data
+             - "median": Use median of background_data
+             - float/int: Use this value for all features
+             - np.ndarray: Use these values directly (must match n_features)
+             - Callable: Function that takes background_data and returns baseline array
+         background_data: Reference data for computing statistics (required for "mean"/"median")
+         n_features: Number of features (required if baseline is scalar)
+
+     Returns:
+         1D numpy array of baseline values, one per feature
+     """
+     if isinstance(baseline, str):
+         if background_data is None:
+             raise ValueError(f"background_data required for baseline='{baseline}'")
+         if baseline == "mean":
+             return np.mean(background_data, axis=0)
+         elif baseline == "median":
+             return np.median(background_data, axis=0)
+         else:
+             raise ValueError(f"Unsupported string baseline: {baseline}")
+
+     elif callable(baseline):
+         if background_data is None:
+             raise ValueError("background_data required for callable baseline")
+         result = baseline(background_data)
+         return np.asarray(result)
+
+     elif isinstance(baseline, np.ndarray):
+         return baseline
+
+     elif isinstance(baseline, (float, int, np.number)):
+         if n_features is None:
+             raise ValueError("n_features required for scalar baseline")
+         return np.full(n_features, baseline)
+
+     else:
+         raise ValueError(f"Invalid baseline type: {type(baseline)}")
+
+
+ def apply_feature_mask(
+     instance: np.ndarray,
+     feature_indices: List[int],
+     baseline_values: np.ndarray
+ ) -> np.ndarray:
+     """
+     Replace specified features with baseline values.
+
+     Args:
+         instance: Original instance (1D array)
+         feature_indices: Indices of features to replace
+         baseline_values: Per-feature baseline values
+
+     Returns:
+         Modified instance with specified features replaced
+     """
+     modified = instance.copy()
+     for idx in feature_indices:
+         if idx < len(modified) and idx < len(baseline_values):
+             modified[idx] = baseline_values[idx]
+     return modified
+
+
+ def resolve_k(k: Union[int, float], n_features: int) -> int:
+     """
+     Resolve k to an integer number of features.
+
+     Args:
+         k: Either an integer count or a float fraction (0-1)
+         n_features: Total number of features
+
+     Returns:
+         Integer number of features
+     """
+     if isinstance(k, float) and 0 < k <= 1:
+         return max(1, int(k * n_features))
+     elif isinstance(k, int) and k > 0:
+         return min(k, n_features)
+     else:
+         raise ValueError(f"k must be positive int or float in (0, 1], got {k}")
+
+
+ def get_prediction_value(
+     model,
+     instance: np.ndarray,
+     output_type: str = "probability"
+ ) -> float:
+     """
+     Get a scalar prediction value from a model.
+
+     Works with both raw sklearn models and explainiverse adapters.
+     For adapters, .predict() typically returns probabilities.
+
+     Args:
+         model: Model adapter with predict/predict_proba methods
+         instance: Single instance (1D array)
+         output_type: "probability" (max prob) or "class" (predicted class)
+
+     Returns:
+         Scalar prediction value
+     """
+     instance_2d = instance.reshape(1, -1)
+
+     if output_type == "probability":
+         # Try predict_proba first (raw sklearn model)
+         if hasattr(model, 'predict_proba'):
+             proba = model.predict_proba(instance_2d)
+             if isinstance(proba, np.ndarray):
+                 if proba.ndim == 2:
+                     return float(np.max(proba[0]))
+                 return float(np.max(proba))
+             return float(np.max(proba[0]))
+
+         # Fall back to predict (adapter returns probs from predict)
+         pred = model.predict(instance_2d)
+         if isinstance(pred, np.ndarray):
+             if pred.ndim == 2:
+                 return float(np.max(pred[0]))
+             elif pred.ndim == 1:
+                 return float(np.max(pred))
+         return float(pred[0]) if hasattr(pred, '__getitem__') else float(pred)
+
+     elif output_type == "class":
+         # For class prediction, use argmax of probabilities
+         if hasattr(model, 'predict_proba'):
+             proba = model.predict_proba(instance_2d)
+             return float(np.argmax(proba[0]))
+         pred = model.predict(instance_2d)
+         if isinstance(pred, np.ndarray) and pred.ndim == 2:
+             return float(np.argmax(pred[0]))
+         return float(pred[0]) if hasattr(pred, '__getitem__') else float(pred)
+
+     else:
+         raise ValueError(f"Unknown output_type: {output_type}")
+
+
+ def compute_prediction_change(
+     model,
+     original: np.ndarray,
+     perturbed: np.ndarray,
+     metric: str = "absolute"
+ ) -> float:
+     """
+     Compute the change in prediction between original and perturbed instances.
+
+     Args:
+         model: Model adapter
+         original: Original instance
+         perturbed: Perturbed instance
+         metric: "absolute" for |p1 - p2|, "relative" for |p1 - p2| / p1
+
+     Returns:
+         Prediction change value
+     """
+     orig_pred = get_prediction_value(model, original)
+     pert_pred = get_prediction_value(model, perturbed)
+
+     if metric == "absolute":
+         return abs(orig_pred - pert_pred)
+     elif metric == "relative":
+         if abs(orig_pred) < 1e-10:
+             return abs(pert_pred)
+         return abs(orig_pred - pert_pred) / abs(orig_pred)
+     else:
+         raise ValueError(f"Unknown metric: {metric}")
@@ -0,0 +1,428 @@
+ # src/explainiverse/evaluation/faithfulness.py
+ """
+ Faithfulness evaluation metrics for explanations.
+
+ Implements:
+ - PGI (Prediction Gap on Important features)
+ - PGU (Prediction Gap on Unimportant features)
+ - Faithfulness Correlation
+ - Comprehensiveness and Sufficiency
+ """
+ import numpy as np
+ import pandas as pd
+ from typing import Union, Callable, List, Dict, Optional
+ from explainiverse.core.explanation import Explanation
+ from explainiverse.evaluation._utils import (
+     get_sorted_feature_indices,
+     compute_baseline_values,
+     apply_feature_mask,
+     resolve_k,
+     get_prediction_value,
+     compute_prediction_change,
+ )
+
+
+ def compute_pgi(
+     model,
+     instance: np.ndarray,
+     explanation: Explanation,
+     k: Union[int, float] = 0.2,
+     baseline: Union[str, float, np.ndarray, Callable] = "mean",
+     background_data: np.ndarray = None,
+ ) -> float:
+     """
+     Compute Prediction Gap on Important features (PGI).
+
+     Measures prediction change when removing the top-k most important features.
+     Higher PGI indicates the explanation correctly identified important features.
+
+     Args:
+         model: Model adapter with predict/predict_proba method
+         instance: Input instance (1D array)
+         explanation: Explanation object with feature_attributions
+         k: Number of top features to remove (int) or fraction (float 0-1)
+         baseline: Baseline for feature replacement ("mean", "median", scalar, array, callable)
+         background_data: Reference data for computing baseline (required for "mean"/"median")
+
+     Returns:
+         PGI score (higher = explanation identified truly important features)
+     """
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     # Resolve k to integer
+     k_int = resolve_k(k, n_features)
+
+     # Get feature indices sorted by importance (most important first)
+     sorted_indices = get_sorted_feature_indices(explanation, descending=True)
+     top_k_indices = sorted_indices[:k_int]
+
+     # Compute baseline values
+     baseline_values = compute_baseline_values(
+         baseline, background_data, n_features
+     )
+
+     # Perturb instance by removing top-k important features
+     perturbed = apply_feature_mask(instance, top_k_indices, baseline_values)
+
+     # Compute prediction change
+     return compute_prediction_change(model, instance, perturbed, metric="absolute")
+
+
+ def compute_pgu(
+     model,
+     instance: np.ndarray,
+     explanation: Explanation,
+     k: Union[int, float] = 0.2,
+     baseline: Union[str, float, np.ndarray, Callable] = "mean",
+     background_data: np.ndarray = None,
+ ) -> float:
+     """
+     Compute Prediction Gap on Unimportant features (PGU).
+
+     Measures prediction change when removing the bottom-k least important features.
+     Lower PGU indicates the explanation correctly identified unimportant features.
+
+     Args:
+         model: Model adapter with predict/predict_proba method
+         instance: Input instance (1D array)
+         explanation: Explanation object with feature_attributions
+         k: Number of bottom features to remove (int) or fraction (float 0-1)
+         baseline: Baseline for feature replacement ("mean", "median", scalar, array, callable)
+         background_data: Reference data for computing baseline (required for "mean"/"median")
+
+     Returns:
+         PGU score (lower = explanation correctly identified unimportant features)
+     """
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     # Resolve k to integer
+     k_int = resolve_k(k, n_features)
+
+     # Get feature indices sorted by importance (least important first for PGU)
+     sorted_indices = get_sorted_feature_indices(explanation, descending=False)
+     bottom_k_indices = sorted_indices[:k_int]
+
+     # Compute baseline values
+     baseline_values = compute_baseline_values(
+         baseline, background_data, n_features
+     )
+
+     # Perturb instance by removing bottom-k unimportant features
+     perturbed = apply_feature_mask(instance, bottom_k_indices, baseline_values)
+
+     # Compute prediction change
+     return compute_prediction_change(model, instance, perturbed, metric="absolute")
+
+
+ def compute_faithfulness_score(
+     model,
+     instance: np.ndarray,
+     explanation: Explanation,
+     k: Union[int, float] = 0.2,
+     baseline: Union[str, float, np.ndarray, Callable] = "mean",
+     background_data: np.ndarray = None,
+     epsilon: float = 1e-7,
+ ) -> Dict[str, float]:
+     """
+     Compute combined faithfulness metrics.
+
+     Args:
+         model: Model adapter
+         instance: Input instance (1D array)
+         explanation: Explanation object
+         k: Number/fraction of features for PGI/PGU
+         baseline: Baseline for feature replacement
+         background_data: Reference data for baseline computation
+         epsilon: Small constant to avoid division by zero
+
+     Returns:
+         Dictionary containing:
+         - pgi: Prediction Gap on Important features
+         - pgu: Prediction Gap on Unimportant features
+         - faithfulness_ratio: PGI / (PGU + epsilon) - higher is better
+         - faithfulness_diff: PGI - PGU - higher is better
+     """
+     pgi = compute_pgi(model, instance, explanation, k, baseline, background_data)
+     pgu = compute_pgu(model, instance, explanation, k, baseline, background_data)
+
+     return {
+         "pgi": pgi,
+         "pgu": pgu,
+         "faithfulness_ratio": pgi / (pgu + epsilon),
+         "faithfulness_diff": pgi - pgu,
+     }
+
+
+ def compute_comprehensiveness(
+     model,
+     instance: np.ndarray,
+     explanation: Explanation,
+     k_values: List[Union[int, float]] = None,
+     baseline: Union[str, float, np.ndarray, Callable] = "mean",
+     background_data: np.ndarray = None,
+ ) -> Dict[str, float]:
+     """
+     Compute comprehensiveness - how much prediction drops when removing important features.
+
+     This is essentially PGI computed at multiple k values and averaged.
+     Higher comprehensiveness = better explanation.
+
+     Args:
+         model: Model adapter
+         instance: Input instance
+         explanation: Explanation object
+         k_values: List of k values to evaluate (default: [0.1, 0.2, 0.3])
+         baseline: Baseline for feature replacement
+         background_data: Reference data
+
+     Returns:
+         Dictionary with per-k scores and mean comprehensiveness
+     """
+     if k_values is None:
+         k_values = [0.1, 0.2, 0.3]
+
+     scores = {}
+     for k in k_values:
+         score = compute_pgi(model, instance, explanation, k, baseline, background_data)
+         scores[f"comp_k{k}"] = score
+
+     scores["comprehensiveness"] = np.mean(list(scores.values()))
+     return scores
+
+
+ def compute_sufficiency(
+     model,
+     instance: np.ndarray,
+     explanation: Explanation,
+     k_values: List[Union[int, float]] = None,
+     baseline: Union[str, float, np.ndarray, Callable] = "mean",
+     background_data: np.ndarray = None,
+ ) -> Dict[str, float]:
+     """
+     Compute sufficiency - how much prediction is preserved when keeping only important features.
+
+     Lower sufficiency = the important features alone are sufficient for prediction.
+
+     Args:
+         model: Model adapter
+         instance: Input instance
+         explanation: Explanation object
+         k_values: List of k values (fraction of features to KEEP)
+         baseline: Baseline for feature replacement
+         background_data: Reference data
+
+     Returns:
+         Dictionary with per-k scores and mean sufficiency
+     """
+     if k_values is None:
+         k_values = [0.1, 0.2, 0.3]
+
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     # Get baseline values
+     baseline_values = compute_baseline_values(baseline, background_data, n_features)
+
+     # Get sorted indices (most important first)
+     sorted_indices = get_sorted_feature_indices(explanation, descending=True)
+
+     scores = {}
+     for k in k_values:
+         k_int = resolve_k(k, n_features)
+
+         # Keep only top-k features, replace rest with baseline
+         top_k_set = set(sorted_indices[:k_int])
+         indices_to_mask = [i for i in range(n_features) if i not in top_k_set]
+
+         perturbed = apply_feature_mask(instance, indices_to_mask, baseline_values)
+         change = compute_prediction_change(model, instance, perturbed, metric="absolute")
+         scores[f"suff_k{k}"] = change
+
+     scores["sufficiency"] = np.mean([v for k, v in scores.items() if k.startswith("suff_k")])
+     return scores
+
+
+ def compute_faithfulness_correlation(
+     model,
+     instance: np.ndarray,
+     explanation: Explanation,
+     baseline: Union[str, float, np.ndarray, Callable] = "mean",
+     background_data: np.ndarray = None,
+     n_steps: int = None,
+ ) -> float:
+     """
+     Compute faithfulness correlation between attributions and prediction changes.
+
+     Measures correlation between feature importance ranking and actual impact
+     on predictions when features are removed one at a time.
+
+     Args:
+         model: Model adapter
+         instance: Input instance
+         explanation: Explanation object
+         baseline: Baseline for feature replacement
+         background_data: Reference data
+         n_steps: Number of features to evaluate (default: all features)
+
+     Returns:
+         Pearson correlation coefficient (-1 to 1, higher is better)
+     """
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     if n_steps is None:
+         n_steps = n_features
+     n_steps = min(n_steps, n_features)
+
+     # Get attributions
+     attributions = explanation.explanation_data.get("feature_attributions", {})
+     sorted_indices = get_sorted_feature_indices(explanation, descending=True)[:n_steps]
+
+     # Get baseline
+     baseline_values = compute_baseline_values(baseline, background_data, n_features)
+
+     # Compute importance values and prediction changes for each feature
+     importance_values = []
+     prediction_changes = []
+
+     feature_names = getattr(explanation, 'feature_names', None)
+
+     for idx in sorted_indices:
+         # Get importance value for this feature
+         if feature_names and idx < len(feature_names):
+             fname = feature_names[idx]
+         else:
+             # Try common naming patterns
+             for pattern in [f"feature_{idx}", f"f{idx}", f"feat_{idx}"]:
+                 if pattern in attributions:
+                     fname = pattern
+                     break
+             else:
+                 fname = list(attributions.keys())[sorted_indices.index(idx)] if idx < len(attributions) else None
+
+         if fname and fname in attributions:
+             importance_values.append(abs(attributions[fname]))
+         else:
+             continue
+
+         # Compute prediction change when removing this single feature
+         perturbed = apply_feature_mask(instance, [idx], baseline_values)
+         change = compute_prediction_change(model, instance, perturbed, metric="absolute")
+         prediction_changes.append(change)
+
+     if len(importance_values) < 2:
+         return 0.0  # Not enough data points
+
+     # Compute Pearson correlation
+     return float(np.corrcoef(importance_values, prediction_changes)[0, 1])
+
+
+ def compare_explainer_faithfulness(
+     model,
+     X: np.ndarray,
+     explanations: Dict[str, List[Explanation]],
+     k: Union[int, float] = 0.2,
+     baseline: Union[str, float, np.ndarray, Callable] = "mean",
+     max_samples: int = None,
+ ) -> pd.DataFrame:
+     """
+     Compare multiple explainers on faithfulness metrics across a dataset.
+
+     Args:
+         model: Model adapter
+         X: Input data (2D array, n_samples x n_features)
+         explanations: Dict mapping explainer names to lists of Explanation objects
+         k: Number/fraction of features for PGI/PGU
+         baseline: Baseline for feature replacement
+         max_samples: Limit number of samples to evaluate (None = all)
+
+     Returns:
+         DataFrame with columns: [explainer, mean_pgi, std_pgi, mean_pgu, std_pgu,
+         mean_ratio, mean_diff, n_samples]
+     """
+     results = []
+
+     for explainer_name, expl_list in explanations.items():
+         n_samples = len(expl_list)
+         if max_samples:
+             n_samples = min(n_samples, max_samples)
+
+         pgi_scores = []
+         pgu_scores = []
+
+         for i in range(n_samples):
+             instance = X[i]
+             exp = expl_list[i]
+
+             try:
+                 scores = compute_faithfulness_score(
+                     model, instance, exp, k, baseline, X
+                 )
+                 pgi_scores.append(scores["pgi"])
+                 pgu_scores.append(scores["pgu"])
+             except Exception as e:
+                 # Skip instances that fail
+                 continue
+
+         if pgi_scores:
+             results.append({
+                 "explainer": explainer_name,
+                 "mean_pgi": np.mean(pgi_scores),
+                 "std_pgi": np.std(pgi_scores),
+                 "mean_pgu": np.mean(pgu_scores),
+                 "std_pgu": np.std(pgu_scores),
+                 "mean_ratio": np.mean(pgi_scores) / (np.mean(pgu_scores) + 1e-7),
+                 "mean_diff": np.mean(pgi_scores) - np.mean(pgu_scores),
+                 "n_samples": len(pgi_scores),
+             })
+
+     return pd.DataFrame(results)
+
+
+ def compute_batch_faithfulness(
+     model,
+     X: np.ndarray,
+     explanations: List[Explanation],
+     k: Union[int, float] = 0.2,
+     baseline: Union[str, float, np.ndarray, Callable] = "mean",
+ ) -> Dict[str, float]:
+     """
+     Compute average faithfulness metrics over a batch of instances.
+
+     Args:
+         model: Model adapter
+         X: Input data (2D array)
+         explanations: List of Explanation objects (one per instance)
+         k: Number/fraction of features for PGI/PGU
+         baseline: Baseline for feature replacement
+
+     Returns:
+         Dictionary with aggregated metrics
+     """
+     pgi_scores = []
+     pgu_scores = []
+
+     for i, exp in enumerate(explanations):
+         try:
+             scores = compute_faithfulness_score(
+                 model, X[i], exp, k, baseline, X
+             )
+             pgi_scores.append(scores["pgi"])
+             pgu_scores.append(scores["pgu"])
+         except Exception:
+             continue
+
+     if not pgi_scores:
+         return {"mean_pgi": 0.0, "mean_pgu": 0.0, "mean_ratio": 0.0, "n_samples": 0}
+
+     return {
+         "mean_pgi": np.mean(pgi_scores),
+         "std_pgi": np.std(pgi_scores),
+         "mean_pgu": np.mean(pgu_scores),
+         "std_pgu": np.std(pgu_scores),
+         "mean_ratio": np.mean(pgi_scores) / (np.mean(pgu_scores) + 1e-7),
+         "mean_diff": np.mean(pgi_scores) - np.mean(pgu_scores),
+         "n_samples": len(pgi_scores),
+     }
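
Usage sketch (not part of the released diff): `compare_explainer_faithfulness` expects explanations that have already been generated per instance. A hedged sketch of the call shape; `model`, `X`, `lime_explanations`, and `shap_explanations` are placeholders for an adapter, a 2D feature matrix, and two pre-computed lists of Explanation objects.

    from explainiverse.evaluation import compare_explainer_faithfulness

    df = compare_explainer_faithfulness(
        model,
        X,
        explanations={"lime": lime_explanations, "shap": shap_explanations},
        k=0.2,
        baseline="mean",
        max_samples=50,
    )
    print(df[["explainer", "mean_pgi", "mean_pgu", "mean_ratio"]])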
@@ -0,0 +1,379 @@
+ # src/explainiverse/evaluation/stability.py
+ """
+ Stability evaluation metrics for explanations.
+
+ Implements:
+ - RIS (Relative Input Stability) - sensitivity to input perturbations
+ - ROS (Relative Output Stability) - consistency with similar predictions
+ - Lipschitz Estimate - local smoothness of explanations
+ """
+ import numpy as np
+ from typing import Union, Callable, List, Dict, Optional, Tuple
+ from explainiverse.core.explanation import Explanation
+ from explainiverse.core.explainer import BaseExplainer
+ from explainiverse.evaluation._utils import get_prediction_value
+
+
+ def _extract_attribution_vector(explanation: Explanation) -> np.ndarray:
+     """
+     Extract attribution values as a numpy array from an Explanation.
+
+     Args:
+         explanation: Explanation object with feature_attributions
+
+     Returns:
+         1D numpy array of attribution values
+     """
+     attributions = explanation.explanation_data.get("feature_attributions", {})
+     if not attributions:
+         raise ValueError("No feature attributions found in explanation.")
+
+     # Get values in consistent order
+     feature_names = getattr(explanation, 'feature_names', None)
+     if feature_names:
+         values = [attributions.get(fn, 0.0) for fn in feature_names]
+     else:
+         values = list(attributions.values())
+
+     return np.array(values, dtype=float)
+
+
+ def _normalize_vector(v: np.ndarray, epsilon: float = 1e-10) -> np.ndarray:
+     """Normalize a vector to unit length."""
+     norm = np.linalg.norm(v)
+     if norm < epsilon:
+         return v
+     return v / norm
+
+
+ def compute_ris(
+     explainer: BaseExplainer,
+     instance: np.ndarray,
+     n_perturbations: int = 10,
+     noise_scale: float = 0.01,
+     seed: int = None,
+ ) -> float:
+     """
+     Compute Relative Input Stability (RIS).
+
+     Measures how stable explanations are to small perturbations in the input.
+     Lower RIS indicates more stable explanations.
+
+     RIS = mean(||E(x) - E(x')|| / ||x - x'||) for perturbed inputs x'
+
+     Args:
+         explainer: Explainer instance with .explain() method
+         instance: Original input instance (1D array)
+         n_perturbations: Number of perturbed samples to generate
+         noise_scale: Standard deviation of Gaussian noise (relative to feature range)
+         seed: Random seed for reproducibility
+
+     Returns:
+         RIS score (lower = more stable)
+     """
+     if seed is not None:
+         np.random.seed(seed)
+
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     # Get original explanation
+     original_exp = explainer.explain(instance)
+     original_exp.feature_names = getattr(original_exp, 'feature_names', None) or \
+         [f"feature_{i}" for i in range(n_features)]
+     original_attr = _extract_attribution_vector(original_exp)
+
+     ratios = []
+
+     for _ in range(n_perturbations):
+         # Generate perturbed input
+         noise = np.random.normal(0, noise_scale, n_features)
+         perturbed = instance + noise * np.abs(instance + 1e-10)  # Scale noise by feature magnitude
+
+         # Get perturbed explanation
+         try:
+             perturbed_exp = explainer.explain(perturbed)
+             perturbed_exp.feature_names = original_exp.feature_names
+             perturbed_attr = _extract_attribution_vector(perturbed_exp)
+         except Exception:
+             continue
+
+         # Compute ratio of changes
+         attr_diff = np.linalg.norm(original_attr - perturbed_attr)
+         input_diff = np.linalg.norm(instance - perturbed)
+
+         if input_diff > 1e-10:
+             ratios.append(attr_diff / input_diff)
+
+     if not ratios:
+         return float('inf')
+
+     return float(np.mean(ratios))
+
+
+ def compute_ros(
+     explainer: BaseExplainer,
+     model,
+     instance: np.ndarray,
+     reference_instances: np.ndarray,
+     n_neighbors: int = 5,
+     prediction_threshold: float = 0.05,
+ ) -> float:
+     """
+     Compute Relative Output Stability (ROS).
+
+     Measures how similar explanations are for instances with similar predictions.
+     Higher ROS indicates more consistent explanations.
+
+     Args:
+         explainer: Explainer instance with .explain() method
+         model: Model adapter with predict/predict_proba method
+         instance: Query instance
+         reference_instances: Pool of reference instances to find neighbors
+         n_neighbors: Number of neighbors to compare
+         prediction_threshold: Maximum prediction difference to consider "similar"
+
+     Returns:
+         ROS score (higher = more consistent for similar predictions)
+     """
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     # Get prediction for query instance
+     query_pred = get_prediction_value(model, instance)
+
+     # Find instances with similar predictions
+     similar_instances = []
+     for ref in reference_instances:
+         ref = np.asarray(ref).flatten()
+         ref_pred = get_prediction_value(model, ref)
+         if abs(query_pred - ref_pred) <= prediction_threshold:
+             similar_instances.append(ref)
+
+     if len(similar_instances) < 2:
+         return 1.0  # Perfect stability if no similar instances
+
+     # Limit to n_neighbors
+     similar_instances = similar_instances[:n_neighbors]
+
+     # Get explanation for query
+     query_exp = explainer.explain(instance)
+     query_exp.feature_names = getattr(query_exp, 'feature_names', None) or \
+         [f"feature_{i}" for i in range(n_features)]
+     query_attr = _normalize_vector(_extract_attribution_vector(query_exp))
+
+     # Get explanations for similar instances and compute similarity
+     similarities = []
+     for ref in similar_instances:
+         try:
+             ref_exp = explainer.explain(ref)
+             ref_exp.feature_names = query_exp.feature_names
+             ref_attr = _normalize_vector(_extract_attribution_vector(ref_exp))
+
+             # Cosine similarity
+             similarity = np.dot(query_attr, ref_attr)
+             similarities.append(similarity)
+         except Exception:
+             continue
+
+     if not similarities:
+         return 1.0
+
+     return float(np.mean(similarities))
+
+
+ def compute_lipschitz_estimate(
+     explainer: BaseExplainer,
+     instance: np.ndarray,
+     n_samples: int = 20,
+     radius: float = 0.1,
+     seed: int = None,
+ ) -> float:
+     """
+     Estimate local Lipschitz constant of the explanation function.
+
+     The Lipschitz constant bounds how fast explanations can change:
+         ||E(x) - E(y)|| <= L * ||x - y||
+
+     Lower L indicates smoother, more stable explanations.
+
+     Args:
+         explainer: Explainer instance
+         instance: Center point for local estimate
+         n_samples: Number of sample pairs to evaluate
+         radius: Radius of ball around instance to sample from
+         seed: Random seed
+
+     Returns:
+         Estimated local Lipschitz constant (lower = smoother)
+     """
+     if seed is not None:
+         np.random.seed(seed)
+
+     instance = np.asarray(instance).flatten()
+     n_features = len(instance)
+
+     max_ratio = 0.0
+
+     for _ in range(n_samples):
+         # Generate two random points in a ball around instance
+         direction1 = np.random.randn(n_features)
+         direction1 = direction1 / np.linalg.norm(direction1)
+         r1 = np.random.uniform(0, radius)
+         point1 = instance + r1 * direction1
+
+         direction2 = np.random.randn(n_features)
+         direction2 = direction2 / np.linalg.norm(direction2)
+         r2 = np.random.uniform(0, radius)
+         point2 = instance + r2 * direction2
+
+         try:
+             exp1 = explainer.explain(point1)
+             exp1.feature_names = [f"feature_{i}" for i in range(n_features)]
+             attr1 = _extract_attribution_vector(exp1)
+
+             exp2 = explainer.explain(point2)
+             exp2.feature_names = exp1.feature_names
+             attr2 = _extract_attribution_vector(exp2)
+         except Exception:
+             continue
+
+         attr_diff = np.linalg.norm(attr1 - attr2)
+         input_diff = np.linalg.norm(point1 - point2)
+
+         if input_diff > 1e-10:
+             ratio = attr_diff / input_diff
+             max_ratio = max(max_ratio, ratio)
+
+     return float(max_ratio)
+
+
+ def compute_stability_metrics(
+     explainer: BaseExplainer,
+     model,
+     instance: np.ndarray,
+     background_data: np.ndarray,
+     n_perturbations: int = 10,
+     noise_scale: float = 0.01,
+     n_neighbors: int = 5,
+     seed: int = None,
+ ) -> Dict[str, float]:
+     """
+     Compute comprehensive stability metrics for a single instance.
+
+     Args:
+         explainer: Explainer instance
+         model: Model adapter
+         instance: Query instance
+         background_data: Reference data for ROS computation
+         n_perturbations: Number of perturbations for RIS
+         noise_scale: Noise scale for RIS
+         n_neighbors: Number of neighbors for ROS
+         seed: Random seed
+
+     Returns:
+         Dictionary with RIS, ROS, and Lipschitz estimate
+     """
+     return {
+         "ris": compute_ris(explainer, instance, n_perturbations, noise_scale, seed),
+         "ros": compute_ros(explainer, model, instance, background_data, n_neighbors),
+         "lipschitz": compute_lipschitz_estimate(explainer, instance, seed=seed),
+     }
+
+
+ def compute_batch_stability(
+     explainer: BaseExplainer,
+     model,
+     X: np.ndarray,
+     n_perturbations: int = 10,
+     noise_scale: float = 0.01,
+     max_samples: int = None,
+     seed: int = None,
+ ) -> Dict[str, float]:
+     """
+     Compute average stability metrics over a batch of instances.
+
+     Args:
+         explainer: Explainer instance
+         model: Model adapter
+         X: Input data (2D array)
+         n_perturbations: Number of perturbations per instance
+         noise_scale: Noise scale for perturbations
+         max_samples: Maximum number of samples to evaluate
+         seed: Random seed
+
+     Returns:
+         Dictionary with mean and std of stability metrics
+     """
+     n_samples = len(X)
+     if max_samples:
+         n_samples = min(n_samples, max_samples)
+
+     ris_scores = []
+     ros_scores = []
+
+     for i in range(n_samples):
+         instance = X[i]
+
+         try:
+             ris = compute_ris(explainer, instance, n_perturbations, noise_scale, seed)
+             if not np.isinf(ris):
+                 ris_scores.append(ris)
+
+             ros = compute_ros(explainer, model, instance, X, n_neighbors=5)
+             ros_scores.append(ros)
+         except Exception:
+             continue
+
+     results = {"n_samples": len(ris_scores)}
+
+     if ris_scores:
+         results["mean_ris"] = np.mean(ris_scores)
+         results["std_ris"] = np.std(ris_scores)
+     else:
+         results["mean_ris"] = float('inf')
+         results["std_ris"] = 0.0
+
+     if ros_scores:
+         results["mean_ros"] = np.mean(ros_scores)
+         results["std_ros"] = np.std(ros_scores)
+     else:
+         results["mean_ros"] = 0.0
+         results["std_ros"] = 0.0
+
+     return results
+
+
+ def compare_explainer_stability(
+     explainers: Dict[str, BaseExplainer],
+     model,
+     X: np.ndarray,
+     n_perturbations: int = 5,
+     noise_scale: float = 0.01,
+     max_samples: int = 20,
+     seed: int = None,
+ ) -> Dict[str, Dict[str, float]]:
+     """
+     Compare stability metrics across multiple explainers.
+
+     Args:
+         explainers: Dict mapping explainer names to explainer instances
+         model: Model adapter
+         X: Input data
+         n_perturbations: Number of perturbations per instance
+         noise_scale: Noise scale
+         max_samples: Max samples to evaluate per explainer
+         seed: Random seed
+
+     Returns:
+         Dict mapping explainer names to their stability metrics
+     """
+     results = {}
+
+     for name, explainer in explainers.items():
+         metrics = compute_batch_stability(
+             explainer, model, X, n_perturbations, noise_scale, max_samples, seed
+         )
+         results[name] = metrics
+
+     return results
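
Usage sketch (not part of the released diff): a minimal sketch of the stability comparison, assuming `explainer_a` and `explainer_b` are explainer instances exposing `.explain()`, `model` is the matching adapter, and `X` is a 2D feature matrix; all four names are placeholders, not symbols from this release.

    from explainiverse.evaluation import compare_explainer_stability

    report = compare_explainer_stability(
        explainers={"explainer_a": explainer_a, "explainer_b": explainer_b},
        model=model,
        X=X,
        n_perturbations=5,
        noise_scale=0.01,
        max_samples=20,
        seed=42,
    )
    # report["explainer_a"] -> {"n_samples": ..., "mean_ris": ..., "std_ris": ..., "mean_ros": ..., "std_ros": ...}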
@@ -1,8 +0,0 @@
1
- # src/explainiverse/evaluation/__init__.py
2
- """
3
- Evaluation metrics for explanation quality.
4
- """
5
-
6
- from explainiverse.evaluation.metrics import compute_aopc, compute_roar
7
-
8
- __all__ = ["compute_aopc", "compute_roar"]