explainiverse 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/__init__.py +1 -1
- explainiverse/adapters/pytorch_adapter.py +88 -25
- explainiverse/core/explanation.py +165 -10
- explainiverse/core/registry.py +18 -0
- explainiverse/engine/suite.py +187 -78
- explainiverse/evaluation/metrics.py +189 -108
- explainiverse/explainers/attribution/lime_wrapper.py +90 -7
- explainiverse/explainers/attribution/shap_wrapper.py +104 -8
- explainiverse/explainers/gradient/__init__.py +12 -0
- explainiverse/explainers/gradient/integrated_gradients.py +189 -76
- explainiverse/explainers/gradient/tcav.py +865 -0
- {explainiverse-0.6.0.dist-info → explainiverse-0.7.1.dist-info}/METADATA +60 -9
- {explainiverse-0.6.0.dist-info → explainiverse-0.7.1.dist-info}/RECORD +15 -14
- {explainiverse-0.6.0.dist-info → explainiverse-0.7.1.dist-info}/LICENSE +0 -0
- {explainiverse-0.6.0.dist-info → explainiverse-0.7.1.dist-info}/WHEEL +0 -0
|
@@ -1,9 +1,68 @@
|
|
|
1
|
+
# src/explainiverse/evaluation/metrics.py
|
|
2
|
+
"""
|
|
3
|
+
Legacy evaluation metrics: AOPC and ROAR.
|
|
4
|
+
|
|
5
|
+
For comprehensive evaluation, prefer the metrics in faithfulness.py
|
|
6
|
+
and stability.py which have better edge case handling.
|
|
7
|
+
"""
|
|
8
|
+
|
|
1
9
|
import numpy as np
|
|
10
|
+
import re
|
|
11
|
+
from typing import List, Dict, Optional, Union, Callable
|
|
2
12
|
from explainiverse.core.explanation import Explanation
|
|
3
13
|
from sklearn.metrics import accuracy_score
|
|
4
14
|
import copy
|
|
5
15
|
|
|
6
16
|
|
|
17
|
+
def _extract_feature_index(
    feature_name: str,
    feature_names: Optional[List[str]] = None,
    fallback_index: int = 0
) -> int:
    """
    Extract feature index from a feature name string.

    Handles various naming conventions, including LIME-style discretized
    conditions such as ``"feature_0 <= 5.0"`` and two-sided bins such as
    ``"0.50 < feature_0 <= 1.20"``.

    Resolution order:
        1. Exact match of ``feature_name`` in ``feature_names``.
        2. Exact match of the condition-stripped base name.
        3. Longest canonical name contained in ``feature_name`` (so that
           ``"feature_10"`` wins over its prefix ``"feature_1"``).
        4. An index parsed from patterns like ``feature_2``, ``feat2``,
           ``f2`` or ``x2``.
        5. ``fallback_index``.

    Args:
        feature_name: Feature name (possibly with conditions)
        feature_names: Optional sequence of canonical feature names
            (list, tuple or 1D array)
        fallback_index: Index to return if extraction fails

    Returns:
        Feature index. Indices parsed from name patterns (step 4) are not
        range-checked; callers must validate them against the feature count.
    """
    text = str(feature_name)
    # Normalise to a list so ``.index`` works for tuples / numpy arrays too.
    names = list(feature_names) if feature_names is not None else None

    # Try exact match first
    if names is not None and feature_name in names:
        return names.index(feature_name)

    # Strip LIME-style conditions: a trailing "<op> number" and, for two-sided
    # bins, a leading "number <op>". Numbers may be signed / in e-notation.
    number = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
    base_name = re.sub(r'\s*[<>=!]+\s*' + number + r'\s*$', '', text)
    base_name = re.sub(r'^\s*' + number + r'\s*[<>]=?\s*', '', base_name).strip()

    if names is not None:
        if base_name in names:
            return names.index(base_name)

        # Partial match: prefer the LONGEST canonical name contained in the
        # key, otherwise "feature_1" would shadow "feature_10". Empty names
        # are skipped because "" is a substring of everything.
        best = None
        for i, fname in enumerate(names):
            if isinstance(fname, str) and fname and fname in text:
                if best is None or len(fname) > len(names[best]):
                    best = i
        if best is not None:
            return best

    # Try extracting index from patterns like "feature_2", "f2", "x2".
    # Matched against the stripped base name so anchored patterns also work
    # for conditions such as "x2 <= 3.0".
    patterns = (
        r'feature[_\s]*(\d+)',
        r'feat[_\s]*(\d+)',
        r'^f(\d+)$',
        r'^x(\d+)$',
    )
    for pattern in patterns:
        match = re.search(pattern, base_name, re.IGNORECASE)
        if match:
            return int(match.group(1))

    return fallback_index
|
|
64
|
+
|
|
65
|
+
|
|
7
66
|
def compute_aopc(
|
|
8
67
|
model,
|
|
9
68
|
instance: np.ndarray,
|
|
@@ -12,39 +71,50 @@ def compute_aopc(
|
|
|
12
71
|
baseline_value: float = 0.0
|
|
13
72
|
) -> float:
|
|
14
73
|
"""
|
|
15
|
-
|
|
74
|
+
Compute Area Over the Perturbation Curve (AOPC).
|
|
75
|
+
|
|
76
|
+
AOPC measures explanation faithfulness by iteratively removing
|
|
77
|
+
the most important features and measuring prediction change.
|
|
16
78
|
|
|
17
79
|
Args:
|
|
18
|
-
model:
|
|
19
|
-
instance:
|
|
20
|
-
explanation: Explanation object
|
|
21
|
-
num_steps:
|
|
22
|
-
baseline_value:
|
|
80
|
+
model: Model adapter with .predict() method
|
|
81
|
+
instance: Input sample (1D array)
|
|
82
|
+
explanation: Explanation object with feature_attributions
|
|
83
|
+
num_steps: Number of top features to remove
|
|
84
|
+
baseline_value: Value to replace removed features with
|
|
23
85
|
|
|
24
86
|
Returns:
|
|
25
|
-
AOPC score (higher
|
|
87
|
+
AOPC score (higher = more faithful explanation)
|
|
26
88
|
"""
|
|
89
|
+
instance = np.asarray(instance).flatten()
|
|
90
|
+
n_features = len(instance)
|
|
91
|
+
|
|
27
92
|
base_pred = model.predict(instance.reshape(1, -1))[0]
|
|
93
|
+
if hasattr(base_pred, '__len__') and len(base_pred) > 1:
|
|
94
|
+
base_pred = float(np.max(base_pred))
|
|
95
|
+
else:
|
|
96
|
+
base_pred = float(base_pred)
|
|
97
|
+
|
|
28
98
|
attributions = explanation.explanation_data.get("feature_attributions", {})
|
|
29
|
-
|
|
30
99
|
if not attributions:
|
|
31
100
|
raise ValueError("No feature attributions found in explanation.")
|
|
32
101
|
|
|
33
|
-
# Sort features by
|
|
102
|
+
# Sort features by absolute importance (most important first)
|
|
34
103
|
sorted_features = sorted(
|
|
35
104
|
attributions.items(),
|
|
36
105
|
key=lambda x: abs(x[1]),
|
|
37
106
|
reverse=True
|
|
38
107
|
)
|
|
39
108
|
|
|
40
|
-
#
|
|
109
|
+
# Get feature_names from explanation (may be None)
|
|
110
|
+
feature_names = getattr(explanation, 'feature_names', None)
|
|
111
|
+
|
|
112
|
+
# Map feature names to indices
|
|
41
113
|
feature_indices = []
|
|
42
114
|
for i, (fname, _) in enumerate(sorted_features):
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
idx = i # fallback: assume order
|
|
47
|
-
feature_indices.append(idx)
|
|
115
|
+
idx = _extract_feature_index(fname, feature_names, fallback_index=i)
|
|
116
|
+
if 0 <= idx < n_features:
|
|
117
|
+
feature_indices.append(idx)
|
|
48
118
|
|
|
49
119
|
deltas = []
|
|
50
120
|
modified = instance.copy()
|
|
@@ -52,42 +122,52 @@ def compute_aopc(
|
|
|
52
122
|
for i in range(min(num_steps, len(feature_indices))):
|
|
53
123
|
idx = feature_indices[i]
|
|
54
124
|
modified[idx] = baseline_value
|
|
125
|
+
|
|
55
126
|
new_pred = model.predict(modified.reshape(1, -1))[0]
|
|
127
|
+
if hasattr(new_pred, '__len__') and len(new_pred) > 1:
|
|
128
|
+
new_pred = float(np.max(new_pred))
|
|
129
|
+
else:
|
|
130
|
+
new_pred = float(new_pred)
|
|
131
|
+
|
|
56
132
|
delta = abs(base_pred - new_pred)
|
|
57
133
|
deltas.append(delta)
|
|
58
134
|
|
|
59
|
-
return np.mean(deltas)
|
|
135
|
+
return float(np.mean(deltas)) if deltas else 0.0
|
|
60
136
|
|
|
61
137
|
|
|
62
138
|
def compute_batch_aopc(
    model,
    X: np.ndarray,
    explanations: Dict[str, List[Explanation]],
    num_steps: int = 10,
    baseline_value: float = 0.0
) -> Dict[str, float]:
    """
    Compute average AOPC across multiple explainers and instances.

    Explanations are paired positionally with the rows of ``X``; any
    explanations beyond ``len(X)`` are ignored. Scoring is best-effort: an
    instance whose AOPC computation raises is skipped, and the failures for
    each explainer are reported in a single ``RuntimeWarning`` instead of
    being silently discarded.

    Args:
        model: Model adapter with a ``.predict()`` method
        X: 2D input array (n_samples, n_features); array-likes such as
            nested lists or DataFrames are converted to an ndarray
        explanations: Dict mapping explainer names to lists of Explanation objects
        num_steps: Number of top features to remove
        baseline_value: Value to replace features with

    Returns:
        Dict mapping explainer names to mean AOPC scores. An explainer for
        which no instance could be scored maps to 0.0 (and a warning is
        emitted describing the failures).
    """
    import warnings

    # Row-wise positional access; X[i] on a DataFrame would select a column.
    X = np.asarray(X)
    n_rows = len(X)
    results = {}

    for explainer_name, expl_list in explanations.items():
        scores = []
        attempted = 0
        failures = 0
        first_error = None

        for i, exp in enumerate(expl_list):
            if i >= n_rows:
                break
            attempted += 1
            try:
                score = compute_aopc(model, X[i], exp, num_steps, baseline_value)
            except Exception as err:  # best-effort: keep scoring the rest
                failures += 1
                if first_error is None:
                    first_error = err
                continue
            scores.append(score)

        if failures:
            warnings.warn(
                f"compute_batch_aopc: {failures} of {attempted} instance(s) "
                f"failed for explainer {explainer_name!r}; "
                f"first error: {first_error!r}",
                RuntimeWarning,
                stacklevel=2,
            )

        results[explainer_name] = float(np.mean(scores)) if scores else 0.0

    return results
|
|
93
173
|
|
|
@@ -98,136 +178,137 @@ def compute_roar(
|
|
|
98
178
|
y_train: np.ndarray,
|
|
99
179
|
X_test: np.ndarray,
|
|
100
180
|
y_test: np.ndarray,
|
|
101
|
-
explanations:
|
|
181
|
+
explanations: List[Explanation],
|
|
102
182
|
top_k: int = 3,
|
|
103
|
-
baseline_value: float = 0.0,
|
|
104
|
-
model_kwargs:
|
|
183
|
+
baseline_value: Union[str, float, np.ndarray, Callable] = 0.0,
|
|
184
|
+
model_kwargs: Optional[Dict] = None
|
|
105
185
|
) -> float:
|
|
106
186
|
"""
|
|
107
|
-
Compute ROAR (Remove And Retrain)
|
|
187
|
+
Compute ROAR (Remove And Retrain) score.
|
|
188
|
+
|
|
189
|
+
ROAR retrains the model after removing top-k important features
|
|
190
|
+
and measures the accuracy drop.
|
|
108
191
|
|
|
109
192
|
Args:
|
|
110
|
-
model_class:
|
|
111
|
-
X_train:
|
|
112
|
-
y_train:
|
|
113
|
-
X_test:
|
|
114
|
-
y_test:
|
|
115
|
-
explanations:
|
|
116
|
-
top_k:
|
|
117
|
-
baseline_value:
|
|
118
|
-
|
|
193
|
+
model_class: Uninstantiated model class (e.g., LogisticRegression)
|
|
194
|
+
X_train: Training features
|
|
195
|
+
y_train: Training labels
|
|
196
|
+
X_test: Test features
|
|
197
|
+
y_test: Test labels
|
|
198
|
+
explanations: List of Explanation objects (one per training instance)
|
|
199
|
+
top_k: Number of top features to remove
|
|
200
|
+
baseline_value: Replacement value for removed features:
|
|
201
|
+
- float/int: constant value
|
|
202
|
+
- "mean": per-feature mean from X_train
|
|
203
|
+
- "median": per-feature median from X_train
|
|
204
|
+
- np.ndarray: per-feature values
|
|
205
|
+
- callable: function(X_train) -> per-feature values
|
|
206
|
+
model_kwargs: Optional kwargs for model_class
|
|
119
207
|
|
|
120
208
|
Returns:
|
|
121
209
|
Accuracy drop (baseline_acc - retrained_acc)
|
|
122
210
|
"""
|
|
123
211
|
model_kwargs = model_kwargs or {}
|
|
212
|
+
n_features = X_train.shape[1]
|
|
124
213
|
|
|
125
|
-
#
|
|
214
|
+
# Train baseline model
|
|
126
215
|
baseline_model = model_class(**model_kwargs)
|
|
127
216
|
baseline_model.fit(X_train, y_train)
|
|
128
|
-
|
|
129
|
-
baseline_acc = accuracy_score(y_test, baseline_preds)
|
|
217
|
+
baseline_acc = accuracy_score(y_test, baseline_model.predict(X_test))
|
|
130
218
|
|
|
131
|
-
#
|
|
132
|
-
|
|
133
|
-
for exp in explanations:
|
|
134
|
-
for fname, val in sorted(exp.explanation_data["feature_attributions"].items(), key=lambda x: abs(x[1]), reverse=True)[:top_k]:
|
|
135
|
-
try:
|
|
136
|
-
idx = exp.feature_names.index(fname)
|
|
137
|
-
feature_counts[idx] = feature_counts.get(idx, 0) + 1
|
|
138
|
-
except:
|
|
139
|
-
continue
|
|
140
|
-
|
|
141
|
-
top_features = sorted(feature_counts.items(), key=lambda x: x[1], reverse=True)[:top_k]
|
|
142
|
-
top_feature_indices = [idx for idx, _ in top_features]
|
|
143
|
-
|
|
144
|
-
# Remove top-k from training and test data
|
|
145
|
-
X_train_mod = copy.deepcopy(X_train)
|
|
146
|
-
X_test_mod = copy.deepcopy(X_test)
|
|
219
|
+
# Collect top-k feature indices via voting across explanations
|
|
220
|
+
feature_votes: Dict[int, int] = {}
|
|
147
221
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
222
|
+
for exp in explanations:
|
|
223
|
+
attributions = exp.explanation_data.get("feature_attributions", {})
|
|
224
|
+
if not attributions:
|
|
225
|
+
continue
|
|
226
|
+
|
|
227
|
+
# Get feature_names from explanation
|
|
228
|
+
feature_names = getattr(exp, 'feature_names', None)
|
|
229
|
+
|
|
230
|
+
# Get top-k features by absolute importance
|
|
231
|
+
sorted_attrs = sorted(
|
|
232
|
+
attributions.items(),
|
|
233
|
+
key=lambda x: abs(x[1]),
|
|
234
|
+
reverse=True
|
|
235
|
+
)[:top_k]
|
|
236
|
+
|
|
237
|
+
for i, (fname, _) in enumerate(sorted_attrs):
|
|
238
|
+
idx = _extract_feature_index(fname, feature_names, fallback_index=i)
|
|
239
|
+
if 0 <= idx < n_features:
|
|
240
|
+
feature_votes[idx] = feature_votes.get(idx, 0) + 1
|
|
241
|
+
|
|
242
|
+
# Select most voted features
|
|
243
|
+
top_features = sorted(feature_votes.items(), key=lambda x: x[1], reverse=True)[:top_k]
|
|
244
|
+
top_indices = [idx for idx, _ in top_features]
|
|
155
245
|
|
|
246
|
+
if not top_indices:
|
|
247
|
+
return 0.0
|
|
248
|
+
|
|
249
|
+
# Compute baseline values
|
|
156
250
|
if isinstance(baseline_value, str):
|
|
157
251
|
if baseline_value == "mean":
|
|
158
252
|
feature_baseline = np.mean(X_train, axis=0)
|
|
159
253
|
elif baseline_value == "median":
|
|
160
254
|
feature_baseline = np.median(X_train, axis=0)
|
|
161
255
|
else:
|
|
162
|
-
raise ValueError(f"Unsupported
|
|
256
|
+
raise ValueError(f"Unsupported baseline: {baseline_value}")
|
|
163
257
|
elif callable(baseline_value):
|
|
164
258
|
feature_baseline = baseline_value(X_train)
|
|
165
259
|
elif isinstance(baseline_value, np.ndarray):
|
|
166
|
-
if baseline_value.shape != (X_train.shape[1],):
|
|
167
|
-
raise ValueError("baseline_value ndarray must match number of features")
|
|
168
260
|
feature_baseline = baseline_value
|
|
169
|
-
elif isinstance(baseline_value, (float, int, np.number)):
|
|
170
|
-
feature_baseline = np.full(X_train.shape[1], baseline_value)
|
|
171
261
|
else:
|
|
172
|
-
|
|
262
|
+
feature_baseline = np.full(n_features, float(baseline_value))
|
|
263
|
+
|
|
264
|
+
# Remove features
|
|
265
|
+
X_train_mod = X_train.copy()
|
|
266
|
+
X_test_mod = X_test.copy()
|
|
173
267
|
|
|
174
|
-
for idx in
|
|
268
|
+
for idx in top_indices:
|
|
175
269
|
X_train_mod[:, idx] = feature_baseline[idx]
|
|
176
270
|
X_test_mod[:, idx] = feature_baseline[idx]
|
|
177
|
-
# X_train_mod[:, idx] = baseline_value
|
|
178
|
-
# X_test_mod[:, idx] = baseline_value
|
|
179
271
|
|
|
180
272
|
# Retrain and evaluate
|
|
181
273
|
retrained_model = model_class(**model_kwargs)
|
|
182
274
|
retrained_model.fit(X_train_mod, y_train)
|
|
183
|
-
|
|
184
|
-
retrained_acc = accuracy_score(y_test, retrained_preds)
|
|
275
|
+
retrained_acc = accuracy_score(y_test, retrained_model.predict(X_test_mod))
|
|
185
276
|
|
|
186
|
-
return baseline_acc - retrained_acc
|
|
277
|
+
return float(baseline_acc - retrained_acc)
|
|
187
278
|
|
|
188
279
|
|
|
189
280
|
def compute_roar_curve(
    model_class,
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    explanations: List[Explanation],
    max_k: int = 5,
    baseline_value: Union[str, float, np.ndarray, Callable] = "mean",
    model_kwargs: Optional[Dict] = None
) -> Dict[int, float]:
    """
    Compute ROAR scores for k=1 to max_k.

    For each ``k`` a fresh model is trained on data whose top-``k`` features
    (as voted by ``explanations``) are replaced by ``baseline_value``, and
    the resulting accuracy drop is recorded via ``compute_roar``.

    Args:
        model_class: Uninstantiated model class (e.g., LogisticRegression)
        X_train: Training features (array-like; converted to an ndarray)
        y_train: Training labels
        X_test: Test features
        y_test: Test labels
        explanations: List of Explanation objects used to rank features
        max_k: Largest number of top features to remove (inclusive)
        baseline_value: Replacement for removed features: "mean", "median",
            a scalar, a per-feature ndarray, or callable(X_train)
        model_kwargs: Optional kwargs passed to model_class

    Returns:
        Dict mapping k to accuracy drop (baseline_acc - retrained_acc).
        Empty when max_k < 1.
    """
    model_kwargs = model_kwargs or {}

    # Convert once so array-likes (lists, DataFrames) work with the
    # positional column assignment performed during feature removal.
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test)
    y_test = np.asarray(y_test)

    curve = {}

    for k in range(1, max_k + 1):
        # Fresh copies each round so a model that mutates its inputs during
        # fit cannot leak state into later rounds.
        acc_drop = compute_roar(
            model_class=model_class,
            X_train=X_train.copy(),
            y_train=y_train.copy(),
            X_test=X_test.copy(),
            y_test=y_test.copy(),
            explanations=explanations,
            top_k=k,
            baseline_value=baseline_value,
            model_kwargs=model_kwargs
        )
        curve[k] = acc_drop

    return curve
|
|
@@ -8,14 +8,36 @@ model (linear regression) to perturbed samples around the instance.
|
|
|
8
8
|
Reference:
|
|
9
9
|
Ribeiro, M.T., Singh, S., & Guestrin, C. (2016). "Why Should I Trust You?":
|
|
10
10
|
Explaining the Predictions of Any Classifier. KDD 2016.
|
|
11
|
+
https://arxiv.org/abs/1602.04938
|
|
11
12
|
"""
|
|
12
13
|
|
|
13
14
|
import numpy as np
|
|
14
|
-
from
|
|
15
|
+
from typing import List, Optional
|
|
15
16
|
|
|
16
17
|
from explainiverse.core.explainer import BaseExplainer
|
|
17
18
|
from explainiverse.core.explanation import Explanation
|
|
18
19
|
|
|
20
|
+
# Availability of the optional ``lime`` dependency is probed lazily (on first
# use) so that importing this module never requires lime to be installed.
# None = not probed yet; True/False = result of the first probe.
_LIME_AVAILABLE = None


def _check_lime_available():
    """Raise ImportError unless the optional ``lime`` package is importable.

    The outcome of the first import attempt is memoised in the module-level
    ``_LIME_AVAILABLE`` flag, so subsequent calls do not retry the import.
    """
    global _LIME_AVAILABLE

    if _LIME_AVAILABLE is None:
        try:
            import lime  # noqa: F401  (imported only to probe availability)
        except ImportError:
            _LIME_AVAILABLE = False
        else:
            _LIME_AVAILABLE = True

    if _LIME_AVAILABLE:
        return

    raise ImportError(
        "LIME is required for LimeExplainer. "
        "Install it with: pip install lime"
    )
|
|
40
|
+
|
|
19
41
|
|
|
20
42
|
class LimeExplainer(BaseExplainer):
|
|
21
43
|
"""
|
|
@@ -34,9 +56,26 @@ class LimeExplainer(BaseExplainer):
|
|
|
34
56
|
class_names: List of class names
|
|
35
57
|
mode: 'classification' or 'regression'
|
|
36
58
|
explainer: The underlying LimeTabularExplainer
|
|
59
|
+
|
|
60
|
+
Example:
|
|
61
|
+
>>> from explainiverse.explainers.attribution import LimeExplainer
|
|
62
|
+
>>> explainer = LimeExplainer(
|
|
63
|
+
... model=adapter,
|
|
64
|
+
... training_data=X_train,
|
|
65
|
+
... feature_names=feature_names,
|
|
66
|
+
... class_names=class_names
|
|
67
|
+
... )
|
|
68
|
+
>>> explanation = explainer.explain(X_test[0])
|
|
37
69
|
"""
|
|
38
70
|
|
|
39
|
-
def __init__(
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
model,
|
|
74
|
+
training_data: np.ndarray,
|
|
75
|
+
feature_names: List[str],
|
|
76
|
+
class_names: List[str],
|
|
77
|
+
mode: str = "classification"
|
|
78
|
+
):
|
|
40
79
|
"""
|
|
41
80
|
Initialize the LIME explainer.
|
|
42
81
|
|
|
@@ -47,20 +86,35 @@ class LimeExplainer(BaseExplainer):
|
|
|
47
86
|
feature_names: List of feature names.
|
|
48
87
|
class_names: List of class names.
|
|
49
88
|
mode: 'classification' or 'regression'.
|
|
89
|
+
|
|
90
|
+
Raises:
|
|
91
|
+
ImportError: If lime package is not installed.
|
|
50
92
|
"""
|
|
93
|
+
# Check availability before importing
|
|
94
|
+
_check_lime_available()
|
|
95
|
+
|
|
96
|
+
# Import after check passes
|
|
97
|
+
from lime.lime_tabular import LimeTabularExplainer
|
|
98
|
+
|
|
51
99
|
super().__init__(model)
|
|
52
100
|
self.feature_names = list(feature_names)
|
|
53
101
|
self.class_names = list(class_names)
|
|
54
102
|
self.mode = mode
|
|
103
|
+
self.training_data = np.asarray(training_data)
|
|
55
104
|
|
|
56
105
|
self.explainer = LimeTabularExplainer(
|
|
57
|
-
training_data=training_data,
|
|
58
|
-
feature_names=feature_names,
|
|
59
|
-
class_names=class_names,
|
|
106
|
+
training_data=self.training_data,
|
|
107
|
+
feature_names=self.feature_names,
|
|
108
|
+
class_names=self.class_names,
|
|
60
109
|
mode=mode
|
|
61
110
|
)
|
|
62
111
|
|
|
63
|
-
def explain(
|
|
112
|
+
def explain(
|
|
113
|
+
self,
|
|
114
|
+
instance: np.ndarray,
|
|
115
|
+
num_features: int = 5,
|
|
116
|
+
top_labels: int = 1
|
|
117
|
+
) -> Explanation:
|
|
64
118
|
"""
|
|
65
119
|
Generate a local explanation for the given instance.
|
|
66
120
|
|
|
@@ -72,6 +126,8 @@ class LimeExplainer(BaseExplainer):
|
|
|
72
126
|
Returns:
|
|
73
127
|
Explanation object with feature attributions
|
|
74
128
|
"""
|
|
129
|
+
instance = np.asarray(instance).flatten()
|
|
130
|
+
|
|
75
131
|
lime_exp = self.explainer.explain_instance(
|
|
76
132
|
data_row=instance,
|
|
77
133
|
predict_fn=self.model.predict,
|
|
@@ -86,5 +142,32 @@ class LimeExplainer(BaseExplainer):
|
|
|
86
142
|
return Explanation(
|
|
87
143
|
explainer_name="LIME",
|
|
88
144
|
target_class=label_name,
|
|
89
|
-
explanation_data={"feature_attributions": attributions}
|
|
145
|
+
explanation_data={"feature_attributions": attributions},
|
|
146
|
+
feature_names=self.feature_names
|
|
90
147
|
)
|
|
148
|
+
|
|
149
|
+
def explain_batch(
    self,
    X: np.ndarray,
    num_features: int = 5,
    top_labels: int = 1
) -> List[Explanation]:
    """
    Explain every row of ``X`` by delegating to ``self.explain``.

    Args:
        X: 2D numpy array of instances; a 1D array is treated as a single
            instance
        num_features: Number of features per explanation
        top_labels: Number of top labels to explain

    Returns:
        List of Explanation objects, one per row, in row order
    """
    batch = np.asarray(X)
    if batch.ndim == 1:
        # Promote a single flat instance to a batch containing one row.
        batch = batch.reshape(1, -1)

    results = []
    for row_idx in range(batch.shape[0]):
        results.append(
            self.explain(
                batch[row_idx],
                num_features=num_features,
                top_labels=top_labels,
            )
        )
    return results
|