explainiverse 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/evaluation/__init__.py +54 -2
- explainiverse/evaluation/_utils.py +325 -0
- explainiverse/evaluation/faithfulness.py +428 -0
- explainiverse/evaluation/stability.py +379 -0
- {explainiverse-0.2.5.dist-info → explainiverse-0.3.0.dist-info}/METADATA +2 -1
- {explainiverse-0.2.5.dist-info → explainiverse-0.3.0.dist-info}/RECORD +8 -5
- {explainiverse-0.2.5.dist-info → explainiverse-0.3.0.dist-info}/LICENSE +0 -0
- {explainiverse-0.2.5.dist-info → explainiverse-0.3.0.dist-info}/WHEEL +0 -0
|
@@ -1,8 +1,60 @@
|
|
|
1
1
|
# src/explainiverse/evaluation/__init__.py
"""
Evaluation metrics for explanation quality.

Includes:
- Faithfulness metrics (PGI, PGU, Comprehensiveness, Sufficiency)
- Stability metrics (RIS, ROS, Lipschitz)
- Perturbation metrics (AOPC, ROAR)
"""

# Perturbation-based metrics (pre-existing API, re-exported unchanged).
from explainiverse.evaluation.metrics import (
    compute_aopc,
    compute_batch_aopc,
    compute_roar,
    compute_roar_curve,
)

# Faithfulness metrics: prediction-gap style evaluation of attributions.
from explainiverse.evaluation.faithfulness import (
    compute_pgi,
    compute_pgu,
    compute_faithfulness_score,
    compute_comprehensiveness,
    compute_sufficiency,
    compute_faithfulness_correlation,
    compare_explainer_faithfulness,
    compute_batch_faithfulness,
)

# Stability metrics: robustness of explanations to input/output perturbations.
from explainiverse.evaluation.stability import (
    compute_ris,
    compute_ros,
    compute_lipschitz_estimate,
    compute_stability_metrics,
    compute_batch_stability,
    compare_explainer_stability,
)

# Explicit public API of the evaluation subpackage.
__all__ = [
    # Perturbation metrics (existing)
    "compute_aopc",
    "compute_batch_aopc",
    "compute_roar",
    "compute_roar_curve",
    # Faithfulness metrics (new)
    "compute_pgi",
    "compute_pgu",
    "compute_faithfulness_score",
    "compute_comprehensiveness",
    "compute_sufficiency",
    "compute_faithfulness_correlation",
    "compare_explainer_faithfulness",
    "compute_batch_faithfulness",
    # Stability metrics (new)
    "compute_ris",
    "compute_ros",
    "compute_lipschitz_estimate",
    "compute_stability_metrics",
    "compute_batch_stability",
    "compare_explainer_stability",
]
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
# src/explainiverse/evaluation/_utils.py
|
|
2
|
+
"""
|
|
3
|
+
Shared utility functions for evaluation metrics.
|
|
4
|
+
"""
|
|
5
|
+
import numpy as np
|
|
6
|
+
import re
|
|
7
|
+
from typing import Union, Callable, List, Tuple
|
|
8
|
+
from explainiverse.core.explanation import Explanation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _extract_base_feature_name(feature_str: str) -> str:
|
|
12
|
+
"""
|
|
13
|
+
Extract the base feature name from LIME-style feature strings.
|
|
14
|
+
|
|
15
|
+
LIME returns strings like "petal width (cm) <= 0.80" or "feature_2 > 3.5".
|
|
16
|
+
This extracts just the feature name part.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
feature_str: Feature string possibly with conditions
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
Base feature name
|
|
23
|
+
"""
|
|
24
|
+
# Remove comparison operators and values
|
|
25
|
+
# Pattern matches: name <= value, name < value, name >= value, name > value, name = value
|
|
26
|
+
patterns = [
|
|
27
|
+
r'^(.+?)\s*<=\s*[\d\.\-]+$',
|
|
28
|
+
r'^(.+?)\s*>=\s*[\d\.\-]+$',
|
|
29
|
+
r'^(.+?)\s*<\s*[\d\.\-]+$',
|
|
30
|
+
r'^(.+?)\s*>\s*[\d\.\-]+$',
|
|
31
|
+
r'^(.+?)\s*=\s*[\d\.\-]+$',
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
for pattern in patterns:
|
|
35
|
+
match = re.match(pattern, feature_str.strip())
|
|
36
|
+
if match:
|
|
37
|
+
return match.group(1).strip()
|
|
38
|
+
|
|
39
|
+
# No match found, return as-is
|
|
40
|
+
return feature_str.strip()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _match_feature_to_index(
    feature_key: str,
    feature_names: List[str]
) -> int:
    """
    Match a feature key (possibly carrying LIME-style conditions) to its
    positional index in ``feature_names``.

    Resolution order: exact name match, then match on the condition-stripped
    base name, then substring containment, then a numeric index parsed from
    conventional names such as "feature_2", "feat_2", "f2", or "x2".

    Args:
        feature_key: Feature name from explanation (may include conditions)
        feature_names: List of original feature names

    Returns:
        Index of the matching feature, or -1 if not found
    """
    # 1) Exact match against the known names.
    try:
        return feature_names.index(feature_key)
    except ValueError:
        pass

    # 2) Match after stripping LIME condition syntax.
    stripped = _extract_base_feature_name(feature_key)
    try:
        return feature_names.index(stripped)
    except ValueError:
        pass

    # 3) Substring containment: a known name appears inside the key.
    for position, known_name in enumerate(feature_names):
        if known_name in feature_key:
            return position

    # 4) Parse an index out of indexed naming conventions.
    index_patterns = (
        r'feature[_\s]*(\d+)',
        r'feat[_\s]*(\d+)',
        r'^f(\d+)$',
        r'^x(\d+)$',
    )
    for index_pattern in index_patterns:
        hit = re.search(index_pattern, feature_key, re.IGNORECASE)
        if hit:
            parsed = int(hit.group(1))
            if 0 <= parsed < len(feature_names):
                return parsed

    return -1
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def get_sorted_feature_indices(
    explanation: Explanation,
    descending: bool = True
) -> List[int]:
    """
    Extract feature indices sorted by absolute attribution value.

    Handles various feature naming conventions:
    - Clean names: "sepal length", "feature_0"
    - LIME-style: "sepal length <= 5.0", "feature_0 > 2.3"
    - Indexed: "f0", "x1", "feat_2"

    Args:
        explanation: Explanation object with feature_attributions
        descending: If True, sort from most to least important

    Returns:
        List of feature indices sorted by importance

    Raises:
        ValueError: If the explanation carries no feature attributions.
    """
    attributions = explanation.explanation_data.get("feature_attributions", {})

    if not attributions:
        raise ValueError("No feature attributions found in explanation.")

    # Sort features by absolute importance (magnitude only; sign is ignored).
    sorted_features = sorted(
        attributions.items(),
        key=lambda x: abs(x[1]),
        reverse=descending
    )

    # Map feature names to indices
    feature_indices = []
    # feature_names is optional metadata on the Explanation object.
    feature_names = getattr(explanation, 'feature_names', None)

    for i, (fname, _) in enumerate(sorted_features):
        if feature_names is not None:
            idx = _match_feature_to_index(fname, feature_names)
            if idx >= 0:
                feature_indices.append(idx)
            else:
                # Fallback: use position in sorted list
                # NOTE(review): the modulo wrap can emit an index already used
                # by a matched feature, so the result may contain duplicates —
                # confirm downstream callers tolerate that.
                feature_indices.append(i % len(feature_names))
        else:
            # No feature_names available - try to extract index from name
            patterns = [
                r'feature[_\s]*(\d+)',
                r'feat[_\s]*(\d+)',
                r'^f(\d+)',
                r'^x(\d+)',
            ]
            found = False
            for pattern in patterns:
                match = re.search(pattern, fname, re.IGNORECASE)
                if match:
                    # Index taken verbatim from the name; not range-checked here.
                    feature_indices.append(int(match.group(1)))
                    found = True
                    break
            if not found:
                # Last resort: position in the importance-sorted list.
                feature_indices.append(i)

    return feature_indices
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def compute_baseline_values(
    baseline: Union[str, float, np.ndarray, Callable],
    background_data: np.ndarray = None,
    n_features: int = None
) -> np.ndarray:
    """
    Compute per-feature baseline values used when "removing" features.

    Args:
        baseline: Baseline specification - one of:
            - "mean": column-wise mean of background_data
            - "median": column-wise median of background_data
            - float/int: this constant for every feature
            - np.ndarray: used directly (must match n_features)
            - Callable: called with background_data; result used as baseline
        background_data: Reference data for statistics (required for the
            string and callable forms)
        n_features: Number of features (required for the scalar form)

    Returns:
        1D numpy array of baseline values, one per feature

    Raises:
        ValueError: On a missing prerequisite or an unrecognized baseline.
    """
    if isinstance(baseline, str):
        # Named statistics are computed column-wise over the background set.
        if background_data is None:
            raise ValueError(f"background_data required for baseline='{baseline}'")
        reducers = {"mean": np.mean, "median": np.median}
        if baseline not in reducers:
            raise ValueError(f"Unsupported string baseline: {baseline}")
        return reducers[baseline](background_data, axis=0)

    if callable(baseline):
        if background_data is None:
            raise ValueError("background_data required for callable baseline")
        return np.asarray(baseline(background_data))

    if isinstance(baseline, np.ndarray):
        # Explicit per-feature values: used as given.
        return baseline

    if isinstance(baseline, (float, int, np.number)):
        if n_features is None:
            raise ValueError("n_features required for scalar baseline")
        return np.full(n_features, baseline)

    raise ValueError(f"Invalid baseline type: {type(baseline)}")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def apply_feature_mask(
    instance: np.ndarray,
    feature_indices: List[int],
    baseline_values: np.ndarray
) -> np.ndarray:
    """
    Return a copy of *instance* with the selected features set to baseline.

    Indices beyond either the instance or the baseline array are silently
    ignored (best-effort masking); the input instance itself is not mutated.

    Args:
        instance: Original instance (1D array)
        feature_indices: Indices of features to replace
        baseline_values: Per-feature baseline values

    Returns:
        Modified copy with the specified features replaced
    """
    masked = instance.copy()
    # A single bound covers both arrays, replacing the pair of per-index checks.
    upper_bound = min(len(masked), len(baseline_values))
    for position in feature_indices:
        if position < upper_bound:
            masked[position] = baseline_values[position]
    return masked
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def resolve_k(k: Union[int, float], n_features: int) -> int:
    """
    Resolve k to an integer number of features.

    Accepts plain Python numbers as well as NumPy scalar types
    (np.integer / np.floating), which commonly leak out of array code.

    Args:
        k: Either an integer count (>= 1) or a float fraction in (0, 1]
        n_features: Total number of features

    Returns:
        Integer number of features, capped at n_features and at least 1

    Raises:
        ValueError: If k is neither a positive int nor a float in (0, 1].
    """
    # Fractions: floats in (0, 1] select a proportion of the features.
    if isinstance(k, (float, np.floating)) and 0 < k <= 1:
        return max(1, int(k * n_features))
    # Counts: positive integers are capped at the feature count.
    if isinstance(k, (int, np.integer)) and k > 0:
        return min(int(k), n_features)
    raise ValueError(f"k must be positive int or float in (0, 1], got {k}")
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def get_prediction_value(
    model,
    instance: np.ndarray,
    output_type: str = "probability"
) -> float:
    """
    Reduce a model's output on a single instance to one scalar.

    Works with both raw sklearn-style models and explainiverse adapters;
    adapters typically expose probabilities through .predict().

    Args:
        model: Model adapter with predict/predict_proba methods
        instance: Single instance (1D array)
        output_type: "probability" (max prob) or "class" (predicted class)

    Returns:
        Scalar prediction value

    Raises:
        ValueError: For an unrecognized output_type.
    """
    batch = instance.reshape(1, -1)
    has_proba = hasattr(model, 'predict_proba')

    if output_type == "probability":
        if has_proba:
            scores = model.predict_proba(batch)
            if isinstance(scores, np.ndarray) and scores.ndim != 2:
                # Unbatched array of class probabilities.
                return float(np.max(scores))
            # Batched (2D) array, or any other sequence-like output: first row.
            return float(np.max(scores[0]))

        # Fall back to predict (adapters may return probabilities here).
        output = model.predict(batch)
        if isinstance(output, np.ndarray):
            if output.ndim == 2:
                return float(np.max(output[0]))
            if output.ndim == 1:
                return float(np.max(output))
        return float(output[0]) if hasattr(output, '__getitem__') else float(output)

    if output_type == "class":
        # Predicted class is the argmax over class probabilities.
        if has_proba:
            scores = model.predict_proba(batch)
            return float(np.argmax(scores[0]))
        output = model.predict(batch)
        if isinstance(output, np.ndarray) and output.ndim == 2:
            return float(np.argmax(output[0]))
        return float(output[0]) if hasattr(output, '__getitem__') else float(output)

    raise ValueError(f"Unknown output_type: {output_type}")
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def compute_prediction_change(
    model,
    original: np.ndarray,
    perturbed: np.ndarray,
    metric: str = "absolute"
) -> float:
    """
    Quantify how much the model's prediction moves between two inputs.

    Args:
        model: Model adapter
        original: Original instance
        perturbed: Perturbed instance
        metric: "absolute" for |p1 - p2|, "relative" for |p1 - p2| / |p1|

    Returns:
        Non-negative prediction change

    Raises:
        ValueError: For an unrecognized metric name.
    """
    before = get_prediction_value(model, original)
    after = get_prediction_value(model, perturbed)
    gap = abs(before - after)

    if metric == "absolute":
        return gap
    if metric == "relative":
        # Guard against dividing by a (near-)zero original prediction.
        if abs(before) < 1e-10:
            return abs(after)
        return gap / abs(before)
    raise ValueError(f"Unknown metric: {metric}")
|
|
@@ -0,0 +1,428 @@
|
|
|
1
|
+
# src/explainiverse/evaluation/faithfulness.py
|
|
2
|
+
"""
|
|
3
|
+
Faithfulness evaluation metrics for explanations.
|
|
4
|
+
|
|
5
|
+
Implements:
|
|
6
|
+
- PGI (Prediction Gap on Important features)
|
|
7
|
+
- PGU (Prediction Gap on Unimportant features)
|
|
8
|
+
- Faithfulness Correlation
|
|
9
|
+
- Comprehensiveness and Sufficiency
|
|
10
|
+
"""
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from typing import Union, Callable, List, Dict, Optional
|
|
14
|
+
from explainiverse.core.explanation import Explanation
|
|
15
|
+
from explainiverse.evaluation._utils import (
|
|
16
|
+
get_sorted_feature_indices,
|
|
17
|
+
compute_baseline_values,
|
|
18
|
+
apply_feature_mask,
|
|
19
|
+
resolve_k,
|
|
20
|
+
get_prediction_value,
|
|
21
|
+
compute_prediction_change,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def compute_pgi(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
) -> float:
    """
    Compute Prediction Gap on Important features (PGI).

    Replaces the top-k most important features (per the explanation) with
    baseline values and reports the resulting absolute prediction change.
    A higher PGI means the explanation singled out features that actually
    drive the prediction.

    Args:
        model: Model adapter with predict/predict_proba method
        instance: Input instance (1D array)
        explanation: Explanation object with feature_attributions
        k: Number of top features to remove (int) or fraction (float 0-1)
        baseline: Baseline for feature replacement ("mean", "median", scalar,
            array, callable)
        background_data: Reference data for computing baseline (required for
            "mean"/"median")

    Returns:
        PGI score (higher = explanation identified truly important features)
    """
    flat = np.asarray(instance).flatten()
    total_features = len(flat)

    n_remove = resolve_k(k, total_features)

    # Most important features first; mask the leading n_remove of them.
    ranking = get_sorted_feature_indices(explanation, descending=True)
    fill_values = compute_baseline_values(baseline, background_data, total_features)
    masked = apply_feature_mask(flat, ranking[:n_remove], fill_values)

    return compute_prediction_change(model, flat, masked, metric="absolute")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def compute_pgu(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
) -> float:
    """
    Compute Prediction Gap on Unimportant features (PGU).

    Replaces the bottom-k least important features (per the explanation)
    with baseline values and reports the resulting absolute prediction
    change.  A lower PGU means the explanation correctly labeled those
    features as irrelevant.

    Args:
        model: Model adapter with predict/predict_proba method
        instance: Input instance (1D array)
        explanation: Explanation object with feature_attributions
        k: Number of bottom features to remove (int) or fraction (float 0-1)
        baseline: Baseline for feature replacement ("mean", "median", scalar,
            array, callable)
        background_data: Reference data for computing baseline (required for
            "mean"/"median")

    Returns:
        PGU score (lower = explanation correctly identified unimportant features)
    """
    flat = np.asarray(instance).flatten()
    total_features = len(flat)

    n_remove = resolve_k(k, total_features)

    # Least important features first; mask the leading n_remove of them.
    ranking = get_sorted_feature_indices(explanation, descending=False)
    fill_values = compute_baseline_values(baseline, background_data, total_features)
    masked = apply_feature_mask(flat, ranking[:n_remove], fill_values)

    return compute_prediction_change(model, flat, masked, metric="absolute")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def compute_faithfulness_score(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
    epsilon: float = 1e-7,
) -> Dict[str, float]:
    """
    Compute combined faithfulness metrics (PGI, PGU, and their combinations).

    Args:
        model: Model adapter
        instance: Input instance (1D array)
        explanation: Explanation object
        k: Number/fraction of features for PGI/PGU
        baseline: Baseline for feature replacement
        background_data: Reference data for baseline computation
        epsilon: Small constant to avoid division by zero

    Returns:
        Dictionary containing:
        - pgi: Prediction Gap on Important features
        - pgu: Prediction Gap on Unimportant features
        - faithfulness_ratio: PGI / (PGU + epsilon) - higher is better
        - faithfulness_diff: PGI - PGU - higher is better
    """
    important_gap = compute_pgi(model, instance, explanation, k, baseline, background_data)
    unimportant_gap = compute_pgu(model, instance, explanation, k, baseline, background_data)

    return {
        "pgi": important_gap,
        "pgu": unimportant_gap,
        "faithfulness_ratio": important_gap / (unimportant_gap + epsilon),
        "faithfulness_diff": important_gap - unimportant_gap,
    }
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def compute_comprehensiveness(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k_values: List[Union[int, float]] = None,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
) -> Dict[str, float]:
    """
    Compute comprehensiveness: prediction drop when removing important features.

    This is PGI evaluated at several k values, reported per-k and averaged.
    Higher comprehensiveness = better explanation.

    Args:
        model: Model adapter
        instance: Input instance
        explanation: Explanation object
        k_values: List of k values to evaluate (default: [0.1, 0.2, 0.3])
        baseline: Baseline for feature replacement
        background_data: Reference data

    Returns:
        Dictionary with per-k scores ("comp_k<k>") and mean "comprehensiveness"
    """
    ks = [0.1, 0.2, 0.3] if k_values is None else k_values

    per_k = {
        f"comp_k{k}": compute_pgi(model, instance, explanation, k, baseline, background_data)
        for k in ks
    }
    # Mean over the per-k scores only (computed before the summary key exists).
    per_k["comprehensiveness"] = np.mean(list(per_k.values()))
    return per_k
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def compute_sufficiency(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k_values: List[Union[int, float]] = None,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
) -> Dict[str, float]:
    """
    Compute sufficiency: prediction shift when keeping ONLY the top features.

    Every feature outside the top-k is replaced by its baseline value; the
    reported score is the absolute prediction change.  Lower sufficiency =
    the important features alone are sufficient for the prediction.

    Args:
        model: Model adapter
        instance: Input instance
        explanation: Explanation object
        k_values: List of k values (fraction of features to KEEP;
            default [0.1, 0.2, 0.3])
        baseline: Baseline for feature replacement
        background_data: Reference data

    Returns:
        Dictionary with per-k scores ("suff_k<k>") and mean "sufficiency"
    """
    ks = [0.1, 0.2, 0.3] if k_values is None else k_values

    flat = np.asarray(instance).flatten()
    total_features = len(flat)

    fill_values = compute_baseline_values(baseline, background_data, total_features)
    ranking = get_sorted_feature_indices(explanation, descending=True)

    results = {}
    for k in ks:
        keep_count = resolve_k(k, total_features)

        # Mask everything except the kept top features.
        kept = set(ranking[:keep_count])
        to_mask = [idx for idx in range(total_features) if idx not in kept]

        masked = apply_feature_mask(flat, to_mask, fill_values)
        results[f"suff_k{k}"] = compute_prediction_change(
            model, flat, masked, metric="absolute"
        )

    results["sufficiency"] = np.mean(
        [value for name, value in results.items() if name.startswith("suff_k")]
    )
    return results
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def compute_faithfulness_correlation(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
    n_steps: int = None,
) -> float:
    """
    Compute faithfulness correlation between attributions and prediction changes.

    Measures correlation between feature importance ranking and actual impact
    on predictions when features are removed one at a time.

    Args:
        model: Model adapter
        instance: Input instance
        explanation: Explanation object
        baseline: Baseline for feature replacement
        background_data: Reference data
        n_steps: Number of features to evaluate (default: all features)

    Returns:
        Pearson correlation coefficient (-1 to 1, higher is better);
        0.0 when fewer than two features could be matched.
    """
    instance = np.asarray(instance).flatten()
    n_features = len(instance)

    if n_steps is None:
        n_steps = n_features
    n_steps = min(n_steps, n_features)

    # Get attributions
    attributions = explanation.explanation_data.get("feature_attributions", {})
    sorted_indices = get_sorted_feature_indices(explanation, descending=True)[:n_steps]

    # Get baseline
    baseline_values = compute_baseline_values(baseline, background_data, n_features)

    # Compute importance values and prediction changes for each feature
    importance_values = []
    prediction_changes = []

    feature_names = getattr(explanation, 'feature_names', None)

    for idx in sorted_indices:
        # Get importance value for this feature
        if feature_names and idx < len(feature_names):
            fname = feature_names[idx]
        else:
            # Try common naming patterns
            for pattern in [f"feature_{idx}", f"f{idx}", f"feat_{idx}"]:
                if pattern in attributions:
                    fname = pattern
                    break
            else:
                # for-else: no pattern matched.
                # NOTE(review): keys into the attribution dict by the rank of
                # idx within sorted_indices, not by idx itself — looks
                # intentional as a positional fallback, but verify.
                fname = list(attributions.keys())[sorted_indices.index(idx)] if idx < len(attributions) else None

        if fname and fname in attributions:
            importance_values.append(abs(attributions[fname]))
        else:
            # Unresolvable feature: skip before any prediction-change work,
            # keeping the two lists aligned.
            continue

        # Compute prediction change when removing this single feature
        perturbed = apply_feature_mask(instance, [idx], baseline_values)
        change = compute_prediction_change(model, instance, perturbed, metric="absolute")
        prediction_changes.append(change)

    if len(importance_values) < 2:
        return 0.0  # Not enough data points

    # Compute Pearson correlation
    # NOTE(review): np.corrcoef yields NaN when either series has zero
    # variance (e.g. identical importances) — callers may want to treat NaN
    # as 0.0; confirm intended behavior.
    return float(np.corrcoef(importance_values, prediction_changes)[0, 1])
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def compare_explainer_faithfulness(
    model,
    X: np.ndarray,
    explanations: Dict[str, List[Explanation]],
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    max_samples: int = None,
) -> pd.DataFrame:
    """
    Compare multiple explainers on faithfulness metrics across a dataset.

    Args:
        model: Model adapter
        X: Input data (2D array, n_samples x n_features)
        explanations: Dict mapping explainer names to lists of Explanation objects
        k: Number/fraction of features for PGI/PGU
        baseline: Baseline for feature replacement
        max_samples: Limit number of samples to evaluate (None = all)

    Returns:
        DataFrame with columns: [explainer, mean_pgi, std_pgi, mean_pgu, std_pgu,
                                 mean_ratio, mean_diff, n_samples]
        Explainers for which no instance could be scored are omitted.
    """
    results = []

    for explainer_name, expl_list in explanations.items():
        n_samples = len(expl_list)
        if max_samples:
            n_samples = min(n_samples, max_samples)

        pgi_scores = []
        pgu_scores = []

        for i in range(n_samples):
            # Narrow try body: only the scoring call can legitimately fail.
            try:
                scores = compute_faithfulness_score(
                    model, X[i], expl_list[i], k, baseline, X
                )
            except Exception:
                # Best-effort aggregation: skip instances whose explanation
                # cannot be scored instead of failing the whole comparison.
                continue
            pgi_scores.append(scores["pgi"])
            pgu_scores.append(scores["pgu"])

        if pgi_scores:
            results.append({
                "explainer": explainer_name,
                "mean_pgi": np.mean(pgi_scores),
                "std_pgi": np.std(pgi_scores),
                "mean_pgu": np.mean(pgu_scores),
                "std_pgu": np.std(pgu_scores),
                # epsilon keeps the ratio finite when mean PGU is ~0.
                "mean_ratio": np.mean(pgi_scores) / (np.mean(pgu_scores) + 1e-7),
                "mean_diff": np.mean(pgi_scores) - np.mean(pgu_scores),
                "n_samples": len(pgi_scores),
            })

    return pd.DataFrame(results)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def compute_batch_faithfulness(
    model,
    X: np.ndarray,
    explanations: List[Explanation],
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
) -> Dict[str, float]:
    """
    Compute average faithfulness metrics over a batch of instances.

    Args:
        model: Model adapter
        X: Input data (2D array)
        explanations: List of Explanation objects (one per instance)
        k: Number/fraction of features for PGI/PGU
        baseline: Baseline for feature replacement

    Returns:
        Dictionary with aggregated metrics; zeros with n_samples == 0 when
        no instance could be scored.
    """
    pgi_all = []
    pgu_all = []

    for idx, current_explanation in enumerate(explanations):
        try:
            per_instance = compute_faithfulness_score(
                model, X[idx], current_explanation, k, baseline, X
            )
            pgi_all.append(per_instance["pgi"])
            pgu_all.append(per_instance["pgu"])
        except Exception:
            # Best-effort: skip instances whose explanation fails to score.
            continue

    if not pgi_all:
        return {"mean_pgi": 0.0, "mean_pgu": 0.0, "mean_ratio": 0.0, "n_samples": 0}

    return {
        "mean_pgi": np.mean(pgi_all),
        "std_pgi": np.std(pgi_all),
        "mean_pgu": np.mean(pgu_all),
        "std_pgu": np.std(pgu_all),
        # epsilon keeps the ratio finite when mean PGU is ~0.
        "mean_ratio": np.mean(pgi_all) / (np.mean(pgu_all) + 1e-7),
        "mean_diff": np.mean(pgi_all) - np.mean(pgu_all),
        "n_samples": len(pgi_all),
    }
|
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
# src/explainiverse/evaluation/stability.py
|
|
2
|
+
"""
|
|
3
|
+
Stability evaluation metrics for explanations.
|
|
4
|
+
|
|
5
|
+
Implements:
|
|
6
|
+
- RIS (Relative Input Stability) - sensitivity to input perturbations
|
|
7
|
+
- ROS (Relative Output Stability) - consistency with similar predictions
|
|
8
|
+
- Lipschitz Estimate - local smoothness of explanations
|
|
9
|
+
"""
|
|
10
|
+
import numpy as np
|
|
11
|
+
from typing import Union, Callable, List, Dict, Optional, Tuple
|
|
12
|
+
from explainiverse.core.explanation import Explanation
|
|
13
|
+
from explainiverse.core.explainer import BaseExplainer
|
|
14
|
+
from explainiverse.evaluation._utils import get_prediction_value
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _extract_attribution_vector(explanation: Explanation) -> np.ndarray:
    """
    Pull the attribution values out of an Explanation as a 1D float array.

    If the explanation carries a truthy ``feature_names`` attribute, values
    are ordered by it (features absent from the attribution dict contribute
    0.0); otherwise the dictionary's own iteration order is used.

    Args:
        explanation: Explanation object with feature_attributions

    Returns:
        1D numpy array of attribution values

    Raises:
        ValueError: If the explanation contains no feature attributions.
    """
    attributions = explanation.explanation_data.get("feature_attributions", {})
    if not attributions:
        raise ValueError("No feature attributions found in explanation.")

    names = getattr(explanation, 'feature_names', None)
    if names:
        ordered = [attributions.get(name, 0.0) for name in names]
    else:
        ordered = list(attributions.values())

    return np.array(ordered, dtype=float)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _normalize_vector(v: np.ndarray, epsilon: float = 1e-10) -> np.ndarray:
|
|
42
|
+
"""Normalize a vector to unit length."""
|
|
43
|
+
norm = np.linalg.norm(v)
|
|
44
|
+
if norm < epsilon:
|
|
45
|
+
return v
|
|
46
|
+
return v / norm
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def compute_ris(
    explainer: BaseExplainer,
    instance: np.ndarray,
    n_perturbations: int = 10,
    noise_scale: float = 0.01,
    seed: int = None,
) -> float:
    """
    Compute Relative Input Stability (RIS).

    Measures how stable explanations are to small perturbations in the input.
    Lower RIS indicates more stable explanations.

    RIS = mean(||E(x) - E(x')|| / ||x - x'||) for perturbed inputs x'

    Args:
        explainer: Explainer instance with .explain() method
        instance: Original input instance (1D array)
        n_perturbations: Number of perturbed samples to generate
        noise_scale: Standard deviation of Gaussian noise (relative to feature range)
        seed: Random seed for reproducibility

    Returns:
        RIS score (lower = more stable); ``inf`` when no perturbation
        yielded a usable explanation.
    """
    if seed is not None:
        np.random.seed(seed)

    x = np.asarray(instance).flatten()
    dim = len(x)

    # Baseline explanation for the unperturbed input; fall back to generic
    # feature names so attribution vectors align across calls.
    base_exp = explainer.explain(x)
    base_exp.feature_names = getattr(base_exp, 'feature_names', None) or \
        [f"feature_{i}" for i in range(dim)]
    base_attr = _extract_attribution_vector(base_exp)

    change_ratios = []
    for _ in range(n_perturbations):
        # Noise is scaled by each feature's magnitude so perturbations stay
        # proportionally small (the 1e-10 keeps zero-valued features movable).
        noise = np.random.normal(0, noise_scale, dim)
        x_prime = x + noise * np.abs(x + 1e-10)

        try:
            exp_prime = explainer.explain(x_prime)
            exp_prime.feature_names = base_exp.feature_names
            attr_prime = _extract_attribution_vector(exp_prime)
        except Exception:
            # Skip perturbations the explainer cannot handle.
            continue

        denominator = np.linalg.norm(x - x_prime)
        if denominator > 1e-10:
            change_ratios.append(np.linalg.norm(base_attr - attr_prime) / denominator)

    return float(np.mean(change_ratios)) if change_ratios else float('inf')
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def compute_ros(
    explainer: BaseExplainer,
    model,
    instance: np.ndarray,
    reference_instances: np.ndarray,
    n_neighbors: int = 5,
    prediction_threshold: float = 0.05,
) -> float:
    """
    Compute Relative Output Stability (ROS).

    Measures how similar explanations are for instances with similar predictions.
    Higher ROS indicates more consistent explanations.

    Similarity is the dot product of L2-normalized attribution vectors
    (i.e. cosine similarity), averaged over the selected neighbors.

    Args:
        explainer: Explainer instance with .explain() method
        model: Model adapter with predict/predict_proba method
        instance: Query instance
        reference_instances: Pool of reference instances to find neighbors
        n_neighbors: Number of neighbors to compare
        prediction_threshold: Maximum prediction difference to consider "similar"

    Returns:
        ROS score (higher = more consistent for similar predictions).
        Returns 1.0 when fewer than two similar instances exist or when
        no neighbor explanation succeeds.

    NOTE(review): if the query instance itself appears in
    ``reference_instances`` it passes the threshold test and is compared
    against itself (similarity 1.0), inflating the score — confirm whether
    callers exclude it from the pool.
    """
    instance = np.asarray(instance).flatten()
    n_features = len(instance)

    # Get prediction for query instance
    query_pred = get_prediction_value(model, instance)

    # Find instances whose prediction lies within prediction_threshold of
    # the query's prediction.
    similar_instances = []
    for ref in reference_instances:
        ref = np.asarray(ref).flatten()
        ref_pred = get_prediction_value(model, ref)
        if abs(query_pred - ref_pred) <= prediction_threshold:
            similar_instances.append(ref)

    if len(similar_instances) < 2:
        return 1.0  # Perfect stability if no similar instances

    # Limit to n_neighbors. NOTE(review): this keeps the first n matches in
    # pool order, not the n nearest by prediction distance.
    similar_instances = similar_instances[:n_neighbors]

    # Get explanation for query; fall back to generic feature names so the
    # attribution vectors of query and neighbors align.
    query_exp = explainer.explain(instance)
    query_exp.feature_names = getattr(query_exp, 'feature_names', None) or \
        [f"feature_{i}" for i in range(n_features)]
    query_attr = _normalize_vector(_extract_attribution_vector(query_exp))

    # Get explanations for similar instances and compute similarity
    similarities = []
    for ref in similar_instances:
        try:
            ref_exp = explainer.explain(ref)
            ref_exp.feature_names = query_exp.feature_names
            ref_attr = _normalize_vector(_extract_attribution_vector(ref_exp))

            # Cosine similarity (both vectors already unit length)
            similarity = np.dot(query_attr, ref_attr)
            similarities.append(similarity)
        except Exception:
            # Skip neighbors the explainer cannot handle.
            continue

    if not similarities:
        return 1.0

    return float(np.mean(similarities))
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def compute_lipschitz_estimate(
    explainer: BaseExplainer,
    instance: np.ndarray,
    n_samples: int = 20,
    radius: float = 0.1,
    seed: int = None,
) -> float:
    """
    Estimate local Lipschitz constant of the explanation function.

    The Lipschitz constant bounds how fast explanations can change:
        ||E(x) - E(y)|| <= L * ||x - y||

    Lower L indicates smoother, more stable explanations.

    Args:
        explainer: Explainer instance
        instance: Center point for local estimate
        n_samples: Number of sample pairs to evaluate
        radius: Radius of ball around instance to sample from
        seed: Random seed

    Returns:
        Estimated local Lipschitz constant (lower = smoother);
        0.0 when no sample pair could be evaluated.
    """
    if seed is not None:
        np.random.seed(seed)

    center = np.asarray(instance).flatten()
    dim = len(center)

    def _random_point() -> np.ndarray:
        # Uniform random direction scaled by a random radius within the ball.
        direction = np.random.randn(dim)
        direction = direction / np.linalg.norm(direction)
        return center + np.random.uniform(0, radius) * direction

    def _attribution_at(point: np.ndarray) -> np.ndarray:
        # Explain one point, forcing generic feature names so vectors align.
        exp = explainer.explain(point)
        exp.feature_names = [f"feature_{i}" for i in range(dim)]
        return _extract_attribution_vector(exp)

    best = 0.0
    for _ in range(n_samples):
        a = _random_point()
        b = _random_point()

        try:
            attr_a = _attribution_at(a)
            attr_b = _attribution_at(b)
        except Exception:
            # Ignore pairs the explainer cannot process.
            continue

        gap = np.linalg.norm(a - b)
        if gap > 1e-10:
            best = max(best, np.linalg.norm(attr_a - attr_b) / gap)

    return float(best)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def compute_stability_metrics(
    explainer: BaseExplainer,
    model,
    instance: np.ndarray,
    background_data: np.ndarray,
    n_perturbations: int = 10,
    noise_scale: float = 0.01,
    n_neighbors: int = 5,
    seed: int = None,
) -> Dict[str, float]:
    """
    Compute comprehensive stability metrics for a single instance.

    Bundles RIS (input stability), ROS (output stability), and a local
    Lipschitz estimate into a single dictionary.

    Args:
        explainer: Explainer instance
        model: Model adapter
        instance: Query instance
        background_data: Reference data for ROS computation
        n_perturbations: Number of perturbations for RIS
        noise_scale: Noise scale for RIS
        n_neighbors: Number of neighbors for ROS
        seed: Random seed

    Returns:
        Dictionary with RIS, ROS, and Lipschitz estimate
    """
    # Metrics are computed in RIS -> ROS -> Lipschitz order (the seeded ones
    # share the global RNG, so ordering matters for reproducibility).
    ris_value = compute_ris(
        explainer, instance, n_perturbations, noise_scale, seed
    )
    ros_value = compute_ros(
        explainer, model, instance, background_data, n_neighbors
    )
    lipschitz_value = compute_lipschitz_estimate(explainer, instance, seed=seed)

    return {
        "ris": ris_value,
        "ros": ros_value,
        "lipschitz": lipschitz_value,
    }
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def compute_batch_stability(
    explainer: BaseExplainer,
    model,
    X: np.ndarray,
    n_perturbations: int = 10,
    noise_scale: float = 0.01,
    max_samples: Optional[int] = None,
    seed: Optional[int] = None,
) -> Dict[str, float]:
    """
    Compute average stability metrics over a batch of instances.

    Instances whose computation raises are skipped. RIS values of ``inf``
    (no usable perturbation) are excluded from the RIS aggregate but the
    instance's ROS is still recorded.

    NOTE(review): the same ``seed`` is passed to compute_ris for every
    instance, so perturbation noise is identical across instances —
    confirm this is the intended reproducibility behavior.

    Args:
        explainer: Explainer instance
        model: Model adapter
        X: Input data (2D array)
        n_perturbations: Number of perturbations per instance
        noise_scale: Noise scale for perturbations
        max_samples: Maximum number of samples to evaluate
        seed: Random seed

    Returns:
        Dictionary with mean and std of stability metrics.
        ``n_samples`` counts instances with a finite RIS, which can differ
        from the number of ROS values aggregated.
    """
    n_samples = len(X)
    if max_samples:
        n_samples = min(n_samples, max_samples)

    ris_scores = []
    ros_scores = []

    for i in range(n_samples):
        instance = X[i]

        try:
            ris = compute_ris(explainer, instance, n_perturbations, noise_scale, seed)
            if not np.isinf(ris):
                ris_scores.append(ris)

            # The full batch X serves as the reference pool for ROS.
            ros = compute_ros(explainer, model, instance, X, n_neighbors=5)
            ros_scores.append(ros)
        except Exception:
            # Best-effort batch: skip instances that fail.
            continue

    results = {"n_samples": len(ris_scores)}

    if ris_scores:
        results["mean_ris"] = np.mean(ris_scores)
        results["std_ris"] = np.std(ris_scores)
    else:
        # No finite RIS at all: report worst-case (infinite) instability.
        results["mean_ris"] = float('inf')
        results["std_ris"] = 0.0

    if ros_scores:
        results["mean_ros"] = np.mean(ros_scores)
        results["std_ros"] = np.std(ros_scores)
    else:
        results["mean_ros"] = 0.0
        results["std_ros"] = 0.0

    return results
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def compare_explainer_stability(
    explainers: Dict[str, BaseExplainer],
    model,
    X: np.ndarray,
    n_perturbations: int = 5,
    noise_scale: float = 0.01,
    max_samples: int = 20,
    seed: int = None,
) -> Dict[str, Dict[str, float]]:
    """
    Compare stability metrics across multiple explainers.

    Every explainer is evaluated with identical batch settings so the
    resulting metrics are directly comparable.

    Args:
        explainers: Dict mapping explainer names to explainer instances
        model: Model adapter
        X: Input data
        n_perturbations: Number of perturbations per instance
        noise_scale: Noise scale
        max_samples: Max samples to evaluate per explainer
        seed: Random seed

    Returns:
        Dict mapping explainer names to their stability metrics
    """
    return {
        name: compute_batch_stability(
            explainer, model, X, n_perturbations, noise_scale, max_samples, seed
        )
        for name, explainer in explainers.items()
    }
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: explainiverse
|
|
3
|
-
Version: 0.2.5
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
|
|
5
5
|
Home-page: https://github.com/jemsbhai/explainiverse
|
|
6
6
|
License: MIT
|
|
@@ -20,6 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
20
20
|
Provides-Extra: torch
|
|
21
21
|
Requires-Dist: lime (>=0.2.0.1,<0.3.0.0)
|
|
22
22
|
Requires-Dist: numpy (>=1.24,<2.0)
|
|
23
|
+
Requires-Dist: pandas (>=1.5,<3.0)
|
|
23
24
|
Requires-Dist: scikit-learn (>=1.1,<1.6)
|
|
24
25
|
Requires-Dist: scipy (>=1.10,<2.0)
|
|
25
26
|
Requires-Dist: shap (>=0.48.0,<0.49.0)
|
|
@@ -9,8 +9,11 @@ explainiverse/core/explanation.py,sha256=6zxFh_TH8tFHc-r_H5-WHQ05Sp1Kp2TxLz3gyFe
|
|
|
9
9
|
explainiverse/core/registry.py,sha256=f1GAo2tg6Sjyz-uOPyLukYYSUgMmpb95pI3B6O-5jjo,22992
|
|
10
10
|
explainiverse/engine/__init__.py,sha256=1sZO8nH1mmwK2e-KUavBQm7zYDWUe27nyWoFy9tgsiA,197
|
|
11
11
|
explainiverse/engine/suite.py,sha256=sq8SK_6Pf0qRckTmVJ7Mdosu9bhkjAGPGN8ymLGFP9E,4914
|
|
12
|
-
explainiverse/evaluation/__init__.py,sha256=
|
|
12
|
+
explainiverse/evaluation/__init__.py,sha256=ePE97KwSjg_IChZ03DeQax8GruTjx-BVrMSi_nzoyoA,1501
|
|
13
|
+
explainiverse/evaluation/_utils.py,sha256=ej7YOPZ90gVHuuIMj45EXHq9Jx3QG7lhaj5sk26hRpg,10519
|
|
14
|
+
explainiverse/evaluation/faithfulness.py,sha256=_40afOW6vJ3dQguHlJySlgWqiJF_xIvN-uVA3nPKRvI,14841
|
|
13
15
|
explainiverse/evaluation/metrics.py,sha256=tSBXtyA_-0zOGCGjlPZU6LdGKRH_QpWfgKa78sdlovs,7453
|
|
16
|
+
explainiverse/evaluation/stability.py,sha256=q2d3rpxpp0X1s6ADST1iZA4tzksLJpR0mYBnA_U5FIs,12090
|
|
14
17
|
explainiverse/explainers/__init__.py,sha256=d7DTbUXzdVdN0l5GQnoJ4zzutI0TXNvx0UzwNXoWY9w,2207
|
|
15
18
|
explainiverse/explainers/attribution/__init__.py,sha256=YeVs9bS_IWDtqGbp6T37V6Zp5ZDWzLdAXHxxyFGpiQM,431
|
|
16
19
|
explainiverse/explainers/attribution/lime_wrapper.py,sha256=OnXIV7t6yd-vt38sIi7XmHFbgzlZfCEbRlFyGGd5XiE,3245
|
|
@@ -29,7 +32,7 @@ explainiverse/explainers/gradient/gradcam.py,sha256=ywW_8PhALwegkpSUDQMFvvVFkA5N
|
|
|
29
32
|
explainiverse/explainers/gradient/integrated_gradients.py,sha256=feBgY3Vw2rDti7fxRZtLkxse75m2dbP_R05ARqo2BRM,13367
|
|
30
33
|
explainiverse/explainers/rule_based/__init__.py,sha256=gKzlFCAzwurAMLJcuYgal4XhDj1thteBGcaHWmN7iWk,243
|
|
31
34
|
explainiverse/explainers/rule_based/anchors_wrapper.py,sha256=ML7W6aam-eMGZHy5ilol8qupZvNBJpYAFatEEPnuMyo,13254
|
|
32
|
-
explainiverse-0.
|
|
33
|
-
explainiverse-0.
|
|
34
|
-
explainiverse-0.
|
|
35
|
-
explainiverse-0.
|
|
35
|
+
explainiverse-0.3.0.dist-info/LICENSE,sha256=28rbHe8rJgmUlRdxJACfq1Sj-MtCEhyHxkJedQd1ZYA,1070
|
|
36
|
+
explainiverse-0.3.0.dist-info/METADATA,sha256=F53kTds8YbDDtEes-9dC6lDlxAYXekNSMtIGvwB1eY4,11518
|
|
37
|
+
explainiverse-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
38
|
+
explainiverse-0.3.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|