explainiverse 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/__init__.py +1 -1
- explainiverse/core/registry.py +35 -0
- explainiverse/evaluation/__init__.py +54 -2
- explainiverse/evaluation/_utils.py +325 -0
- explainiverse/evaluation/faithfulness.py +428 -0
- explainiverse/evaluation/stability.py +379 -0
- explainiverse/explainers/__init__.py +3 -0
- explainiverse/explainers/gradient/__init__.py +7 -1
- explainiverse/explainers/gradient/deeplift.py +745 -0
- {explainiverse-0.2.4.dist-info → explainiverse-0.3.0.dist-info}/METADATA +2 -1
- {explainiverse-0.2.4.dist-info → explainiverse-0.3.0.dist-info}/RECORD +13 -9
- {explainiverse-0.2.4.dist-info → explainiverse-0.3.0.dist-info}/LICENSE +0 -0
- {explainiverse-0.2.4.dist-info → explainiverse-0.3.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,428 @@
|
|
|
1
|
+
# src/explainiverse/evaluation/faithfulness.py
|
|
2
|
+
"""
|
|
3
|
+
Faithfulness evaluation metrics for explanations.
|
|
4
|
+
|
|
5
|
+
Implements:
|
|
6
|
+
- PGI (Prediction Gap on Important features)
|
|
7
|
+
- PGU (Prediction Gap on Unimportant features)
|
|
8
|
+
- Faithfulness Correlation
|
|
9
|
+
- Comprehensiveness and Sufficiency
|
|
10
|
+
"""
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from typing import Union, Callable, List, Dict, Optional
|
|
14
|
+
from explainiverse.core.explanation import Explanation
|
|
15
|
+
from explainiverse.evaluation._utils import (
|
|
16
|
+
get_sorted_feature_indices,
|
|
17
|
+
compute_baseline_values,
|
|
18
|
+
apply_feature_mask,
|
|
19
|
+
resolve_k,
|
|
20
|
+
get_prediction_value,
|
|
21
|
+
compute_prediction_change,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def compute_pgi(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
) -> float:
    """
    Compute Prediction Gap on Important features (PGI).

    The ``k`` features that ``explanation`` ranks as most important are
    masked with baseline values. The function then measures the resulting
    change in the model's prediction. If the explanation really found the
    features the model relies on, masking them should move the prediction
    a lot.

    Args:
        model: Model adapter with predict/predict_proba method
        instance: Input instance; flattened to 1D before use
        explanation: Explanation object with feature_attributions
        k: Number of top-ranked features to mask (int) or fraction of all
            features (float 0-1)
        baseline: Replacement values ("mean", "median", scalar, array, callable)
        background_data: Reference data the baseline is derived from
            (required for "mean"/"median")

    Returns:
        PGI score; higher means the explanation identified truly important
        features.
    """
    x = np.asarray(instance).flatten()
    num_features = len(x)

    # Translate k (count or fraction) into a concrete feature count.
    n_masked = resolve_k(k, num_features)
    # Highest-ranked features first, truncated to the ones we will mask.
    most_important = get_sorted_feature_indices(explanation, descending=True)[:n_masked]
    fill_values = compute_baseline_values(baseline, background_data, num_features)

    masked_x = apply_feature_mask(x, most_important, fill_values)
    return compute_prediction_change(model, x, masked_x, metric="absolute")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def compute_pgu(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
) -> float:
    """
    Compute Prediction Gap on Unimportant features (PGU).

    The ``k`` features that ``explanation`` ranks as least important are
    masked with baseline values. The function then measures the resulting
    change in the model's prediction. If the explanation correctly
    identified irrelevant features, masking them should barely move the
    prediction.

    Args:
        model: Model adapter with predict/predict_proba method
        instance: Input instance; flattened to 1D before use
        explanation: Explanation object with feature_attributions
        k: Number of bottom-ranked features to mask (int) or fraction of all
            features (float 0-1)
        baseline: Replacement values ("mean", "median", scalar, array, callable)
        background_data: Reference data the baseline is derived from
            (required for "mean"/"median")

    Returns:
        PGU score; lower means the explanation correctly identified
        unimportant features.
    """
    x = np.asarray(instance).flatten()
    num_features = len(x)

    # Translate k (count or fraction) into a concrete feature count.
    n_masked = resolve_k(k, num_features)
    # Ascending order puts the least important features first.
    least_important = get_sorted_feature_indices(explanation, descending=False)[:n_masked]
    fill_values = compute_baseline_values(baseline, background_data, num_features)

    masked_x = apply_feature_mask(x, least_important, fill_values)
    return compute_prediction_change(model, x, masked_x, metric="absolute")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def compute_faithfulness_score(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
    epsilon: float = 1e-7,
) -> Dict[str, float]:
    """
    Compute PGI and PGU for one instance and combine them.

    Args:
        model: Model adapter
        instance: Input instance (1D array)
        explanation: Explanation object
        k: Number/fraction of features masked for both PGI and PGU
        baseline: Baseline for feature replacement
        background_data: Reference data for baseline computation
        epsilon: Added to PGU in the ratio's denominator to avoid division by zero

    Returns:
        Dictionary containing:
            - pgi: Prediction Gap on Important features
            - pgu: Prediction Gap on Unimportant features
            - faithfulness_ratio: PGI / (PGU + epsilon) - higher is better
            - faithfulness_diff: PGI - PGU - higher is better
    """
    gap_important = compute_pgi(
        model, instance, explanation,
        k=k, baseline=baseline, background_data=background_data,
    )
    gap_unimportant = compute_pgu(
        model, instance, explanation,
        k=k, baseline=baseline, background_data=background_data,
    )

    summary = {"pgi": gap_important, "pgu": gap_unimportant}
    summary["faithfulness_ratio"] = gap_important / (gap_unimportant + epsilon)
    summary["faithfulness_diff"] = gap_important - gap_unimportant
    return summary
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def compute_comprehensiveness(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k_values: List[Union[int, float]] = None,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
) -> Dict[str, float]:
    """
    Compute comprehensiveness: how much the prediction drops when the
    most important features are removed.

    PGI is evaluated once per entry of ``k_values``, and the per-k scores
    are then averaged. Higher comprehensiveness = better explanation.

    Args:
        model: Model adapter
        instance: Input instance
        explanation: Explanation object
        k_values: k values to evaluate (default: [0.1, 0.2, 0.3])
        baseline: Baseline for feature replacement
        background_data: Reference data

    Returns:
        Dictionary with one ``comp_k{k}`` entry per distinct k key, followed
        by ``comprehensiveness``, the mean of those entries.
    """
    ks = [0.1, 0.2, 0.3] if k_values is None else k_values

    # A repeated k maps to the same key; the later value overwrites the earlier.
    per_k = {
        f"comp_k{k_spec}": compute_pgi(
            model, instance, explanation, k_spec, baseline, background_data
        )
        for k_spec in ks
    }
    per_k["comprehensiveness"] = np.mean(list(per_k.values()))
    return per_k
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def compute_sufficiency(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    k_values: List[Union[int, float]] = None,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
) -> Dict[str, float]:
    """
    Compute sufficiency: how well the prediction is preserved when only the
    most important features are kept.

    For each k, every feature outside the top-k is masked with baseline
    values, and the prediction change is recorded. Lower sufficiency means
    the important features alone are sufficient for the prediction.

    Args:
        model: Model adapter
        instance: Input instance
        explanation: Explanation object
        k_values: k values giving the count/fraction of features to KEEP
            (default: [0.1, 0.2, 0.3])
        baseline: Baseline for feature replacement
        background_data: Reference data

    Returns:
        Dictionary with one ``suff_k{k}`` entry per distinct k key, followed
        by ``sufficiency``, the mean of those entries.
    """
    ks = [0.1, 0.2, 0.3] if k_values is None else k_values

    x = np.asarray(instance).flatten()
    num_features = len(x)

    # Baseline and ranking do not depend on k, so compute them once.
    fill_values = compute_baseline_values(baseline, background_data, num_features)
    ranked = get_sorted_feature_indices(explanation, descending=True)

    per_k = {}
    for k_spec in ks:
        kept = set(ranked[:resolve_k(k_spec, num_features)])
        to_mask = [j for j in range(num_features) if j not in kept]
        masked_x = apply_feature_mask(x, to_mask, fill_values)
        per_k[f"suff_k{k_spec}"] = compute_prediction_change(
            model, x, masked_x, metric="absolute"
        )

    # At this point every key is a suff_k* entry, so average them all.
    per_k["sufficiency"] = np.mean(list(per_k.values()))
    return per_k
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _resolve_attribution_key(idx, feature_names, attributions, attribution_keys):
    """Return the key of ``attributions`` that holds feature ``idx``'s score, or None."""
    # Prefer explicit feature names when the explanation provides them.
    if feature_names is not None and idx < len(feature_names):
        name = feature_names[idx]
        return name if name in attributions else None

    # Otherwise try common synthetic naming patterns.
    for candidate in (f"feature_{idx}", f"f{idx}", f"feat_{idx}"):
        if candidate in attributions:
            return candidate

    # Last resort: positional lookup by feature index.
    # NOTE(review): assumes attributions are stored in feature-index order.
    # The previous code indexed by the feature's *rank* here, which paired
    # one feature's prediction change with a different feature's attribution.
    if idx < len(attribution_keys):
        return attribution_keys[idx]
    return None


def compute_faithfulness_correlation(
    model,
    instance: np.ndarray,
    explanation: Explanation,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    background_data: np.ndarray = None,
    n_steps: int = None,
) -> float:
    """
    Compute faithfulness correlation between attributions and prediction changes.

    The top ``n_steps`` features (by explanation ranking) are masked one at a
    time with baseline values. The function returns the Pearson correlation
    between each feature's |attribution| and the resulting absolute
    prediction change.

    Args:
        model: Model adapter
        instance: Input instance (flattened to 1D)
        explanation: Explanation object
        baseline: Baseline for feature replacement
        background_data: Reference data
        n_steps: Number of features to evaluate (default: all features)

    Returns:
        Pearson correlation coefficient (-1 to 1, higher is better). Returns
        0.0 when fewer than two features could be matched to an attribution,
        or when either series is constant (correlation undefined).
    """
    instance = np.asarray(instance).flatten()
    n_features = len(instance)

    if n_steps is None:
        n_steps = n_features
    n_steps = min(n_steps, n_features)

    attributions = explanation.explanation_data.get("feature_attributions", {})
    sorted_indices = get_sorted_feature_indices(explanation, descending=True)[:n_steps]
    baseline_values = compute_baseline_values(baseline, background_data, n_features)

    feature_names = getattr(explanation, "feature_names", None)
    attribution_keys = list(attributions.keys())  # hoisted out of the loop

    importance_values = []
    prediction_changes = []
    for idx in sorted_indices:
        fname = _resolve_attribution_key(idx, feature_names, attributions, attribution_keys)
        if fname is None:
            # No attribution for this feature: skip it so both series stay aligned.
            continue
        importance_values.append(abs(attributions[fname]))

        # Prediction change when masking this single feature.
        perturbed = apply_feature_mask(instance, [idx], baseline_values)
        prediction_changes.append(
            compute_prediction_change(model, instance, perturbed, metric="absolute")
        )

    if len(importance_values) < 2:
        return 0.0  # Not enough data points

    importance = np.asarray(importance_values, dtype=float)
    changes = np.asarray(prediction_changes, dtype=float)
    # np.corrcoef yields NaN (with a RuntimeWarning) for a zero-variance input.
    if importance.max() == importance.min() or changes.max() == changes.min():
        return 0.0

    return float(np.corrcoef(importance, changes)[0, 1])
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def compare_explainer_faithfulness(
    model,
    X: np.ndarray,
    explanations: Dict[str, List["Explanation"]],
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
    max_samples: int = None,
) -> pd.DataFrame:
    """
    Compare multiple explainers on faithfulness metrics across a dataset.

    Explanation ``i`` of each explainer is evaluated against row ``X[i]``,
    and the full ``X`` is used as background data for the baseline.
    Instances whose evaluation raises are skipped. A RuntimeWarning
    summarising the skipped instances is emitted per explainer, and
    explainers with no successfully scored instance are omitted from the
    result.

    Args:
        model: Model adapter
        X: Input data (2D array, n_samples x n_features)
        explanations: Dict mapping explainer names to lists of Explanation objects
        k: Number/fraction of features for PGI/PGU
        baseline: Baseline for feature replacement
        max_samples: Limit number of samples to evaluate (None = all)

    Returns:
        DataFrame with columns: [explainer, mean_pgi, std_pgi, mean_pgu, std_pgu,
        mean_ratio, mean_diff, n_samples]. The columns are present even when
        no rows are produced.
    """
    import warnings

    columns = [
        "explainer", "mean_pgi", "std_pgi", "mean_pgu", "std_pgu",
        "mean_ratio", "mean_diff", "n_samples",
    ]
    results = []

    for explainer_name, expl_list in explanations.items():
        n_samples = len(expl_list)
        # `is not None` so that max_samples=0 means "zero", not "all".
        if max_samples is not None:
            n_samples = min(n_samples, max_samples)

        pgi_scores = []
        pgu_scores = []
        n_failed = 0
        first_error = None

        for i in range(n_samples):
            instance = X[i]
            exp = expl_list[i]
            try:
                scores = compute_faithfulness_score(
                    model, instance, exp, k, baseline, X
                )
            except Exception as err:
                # Best-effort: skip this instance, but remember why.
                n_failed += 1
                if first_error is None:
                    first_error = err
                continue
            pgi_scores.append(scores["pgi"])
            pgu_scores.append(scores["pgu"])

        if n_failed:
            warnings.warn(
                f"compare_explainer_faithfulness: skipped {n_failed} of {n_samples} "
                f"instance(s) for explainer {explainer_name!r} "
                f"(first error: {first_error!r})",
                RuntimeWarning,
                stacklevel=2,
            )

        if pgi_scores:
            mean_pgi = np.mean(pgi_scores)
            mean_pgu = np.mean(pgu_scores)
            results.append({
                "explainer": explainer_name,
                "mean_pgi": mean_pgi,
                "std_pgi": np.std(pgi_scores),
                "mean_pgu": mean_pgu,
                "std_pgu": np.std(pgu_scores),
                "mean_ratio": mean_pgi / (mean_pgu + 1e-7),
                "mean_diff": mean_pgi - mean_pgu,
                "n_samples": len(pgi_scores),
            })

    return pd.DataFrame(results, columns=columns)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def compute_batch_faithfulness(
    model,
    X: np.ndarray,
    explanations: List["Explanation"],
    k: Union[int, float] = 0.2,
    baseline: Union[str, float, np.ndarray, Callable] = "mean",
) -> Dict[str, float]:
    """
    Compute average faithfulness metrics over a batch of instances.

    Explanation ``i`` is evaluated against row ``X[i]``, and the full ``X``
    is used as background data for the baseline. Instances whose evaluation
    raises (including an index error when ``explanations`` is longer than
    ``X``) are skipped, and a single RuntimeWarning summarises them.

    Args:
        model: Model adapter
        X: Input data (2D array)
        explanations: List of Explanation objects (one per instance)
        k: Number/fraction of features for PGI/PGU
        baseline: Baseline for feature replacement

    Returns:
        Dictionary with keys mean_pgi, std_pgi, mean_pgu, std_pgu,
        mean_ratio, mean_diff and n_samples. When no instance could be
        scored, every metric is 0.0 and n_samples is 0, so the dictionary
        always has the same keys.
    """
    import warnings

    pgi_scores = []
    pgu_scores = []
    n_failed = 0
    first_error = None

    for i, exp in enumerate(explanations):
        try:
            scores = compute_faithfulness_score(
                model, X[i], exp, k, baseline, X
            )
        except Exception as err:
            # Best-effort: skip this instance, but remember why.
            n_failed += 1
            if first_error is None:
                first_error = err
            continue
        pgi_scores.append(scores["pgi"])
        pgu_scores.append(scores["pgu"])

    if n_failed:
        warnings.warn(
            f"compute_batch_faithfulness: skipped {n_failed} of {len(explanations)} "
            f"instance(s) (first error: {first_error!r})",
            RuntimeWarning,
            stacklevel=2,
        )

    if not pgi_scores:
        # Same keys as the normal path so callers can index unconditionally.
        return {
            "mean_pgi": 0.0,
            "std_pgi": 0.0,
            "mean_pgu": 0.0,
            "std_pgu": 0.0,
            "mean_ratio": 0.0,
            "mean_diff": 0.0,
            "n_samples": 0,
        }

    mean_pgi = np.mean(pgi_scores)
    mean_pgu = np.mean(pgu_scores)
    return {
        "mean_pgi": mean_pgi,
        "std_pgi": np.std(pgi_scores),
        "mean_pgu": mean_pgu,
        "std_pgu": np.std(pgu_scores),
        "mean_ratio": mean_pgi / (mean_pgu + 1e-7),
        "mean_diff": mean_pgi - mean_pgu,
        "n_samples": len(pgi_scores),
    }
|