explainiverse 0.1.1a1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/__init__.py +45 -1
- explainiverse/adapters/__init__.py +9 -0
- explainiverse/adapters/base_adapter.py +25 -25
- explainiverse/adapters/sklearn_adapter.py +32 -32
- explainiverse/core/__init__.py +22 -0
- explainiverse/core/explainer.py +31 -31
- explainiverse/core/explanation.py +24 -24
- explainiverse/core/registry.py +545 -0
- explainiverse/engine/__init__.py +8 -0
- explainiverse/engine/suite.py +142 -142
- explainiverse/evaluation/__init__.py +8 -0
- explainiverse/evaluation/metrics.py +232 -232
- explainiverse/explainers/__init__.py +38 -0
- explainiverse/explainers/attribution/__init__.py +9 -0
- explainiverse/explainers/attribution/lime_wrapper.py +90 -63
- explainiverse/explainers/attribution/shap_wrapper.py +89 -66
- explainiverse/explainers/counterfactual/__init__.py +8 -0
- explainiverse/explainers/counterfactual/dice_wrapper.py +302 -0
- explainiverse/explainers/global_explainers/__init__.py +23 -0
- explainiverse/explainers/global_explainers/ale.py +191 -0
- explainiverse/explainers/global_explainers/partial_dependence.py +192 -0
- explainiverse/explainers/global_explainers/permutation_importance.py +123 -0
- explainiverse/explainers/global_explainers/sage.py +164 -0
- explainiverse/explainers/rule_based/__init__.py +8 -0
- explainiverse/explainers/rule_based/anchors_wrapper.py +350 -0
- explainiverse-0.2.0.dist-info/METADATA +264 -0
- explainiverse-0.2.0.dist-info/RECORD +29 -0
- explainiverse-0.1.1a1.dist-info/METADATA +0 -128
- explainiverse-0.1.1a1.dist-info/RECORD +0 -19
- {explainiverse-0.1.1a1.dist-info → explainiverse-0.2.0.dist-info}/LICENSE +0 -0
- {explainiverse-0.1.1a1.dist-info → explainiverse-0.2.0.dist-info}/WHEEL +0 -0
@@ -1,66 +1,89 @@
(previous 66-line version of src/explainiverse/explainers/attribution/shap_wrapper.py removed)

# src/explainiverse/explainers/attribution/shap_wrapper.py
"""
SHAP Explainer - SHapley Additive exPlanations.

SHAP values provide a unified measure of feature importance based on
game-theoretic Shapley values, offering both local and global interpretability.

Reference:
    Lundberg, S.M. & Lee, S.I. (2017). A Unified Approach to Interpreting
    Model Predictions. NeurIPS 2017.
"""

import shap
import numpy as np

from explainiverse.core.explainer import BaseExplainer
from explainiverse.core.explanation import Explanation


class ShapExplainer(BaseExplainer):
    """
    SHAP explainer (KernelSHAP-based) for model-agnostic explanations.

    KernelSHAP is a model-agnostic method that approximates SHAP values
    using a weighted linear regression. It works with any model that
    provides predictions.

    Attributes:
        model: Model adapter with .predict() method
        feature_names: List of feature names
        class_names: List of class labels
        explainer: The underlying SHAP KernelExplainer
    """

    def __init__(self, model, background_data, feature_names, class_names):
        """
        Initialize the SHAP explainer.

        Args:
            model: A model adapter with a .predict method.
            background_data: A 2D numpy array used as SHAP background distribution.
                Typically a representative sample of training data.
            feature_names: List of feature names.
            class_names: List of class labels.
        """
        super().__init__(model)
        self.feature_names = list(feature_names)
        self.class_names = list(class_names)
        self.explainer = shap.KernelExplainer(model.predict, background_data)

    def explain(self, instance, top_labels=1):
        """
        Generate SHAP explanation for a single instance.

        Args:
            instance: 1D numpy array of input features.
            top_labels: Number of top classes to explain (default: 1)

        Returns:
            Explanation object with feature attributions
        """
        instance = np.array(instance).reshape(1, -1)  # Ensure 2D
        shap_values = self.explainer.shap_values(instance)

        if isinstance(shap_values, list):
            # Multi-class: list of arrays, one per class
            predicted_probs = self.model.predict(instance)[0]
            top_indices = np.argsort(predicted_probs)[-top_labels:][::-1]
            label_index = top_indices[0]
            label_name = self.class_names[label_index]
            class_shap = shap_values[label_index][0]
        else:
            # Single-class (regression or binary classification)
            label_index = 0
            label_name = self.class_names[0] if self.class_names else "class_0"
            class_shap = shap_values[0]

        # Build attributions dict
        flat_shap = np.array(class_shap).flatten()
        attributions = {
            fname: float(flat_shap[i])
            for i, fname in enumerate(self.feature_names)
        }

        return Explanation(
            explainer_name="SHAP",
            target_class=label_name,
            explanation_data={"feature_attributions": attributions}
        )
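For context, a minimal usage sketch of the new ShapExplainer follows. The ProbaAdapter class is a hypothetical stand-in for a model adapter that exposes .predict() as class probabilities; the package's own SklearnAdapter may use a different constructor.

# Usage sketch (assumes explainiverse 0.2.0, shap, and scikit-learn are installed).
# ProbaAdapter is a hypothetical stand-in adapter, not part of the package.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer


class ProbaAdapter:
    """Minimal adapter exposing .predict() as class probabilities."""

    def __init__(self, clf):
        self.clf = clf

    def predict(self, X):
        return self.clf.predict_proba(np.asarray(X))


data = load_iris()
clf = RandomForestClassifier(random_state=0).fit(data.data, data.target)

explainer = ShapExplainer(
    model=ProbaAdapter(clf),
    background_data=data.data[:50],  # small background sample for KernelSHAP
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),
)
explanation = explainer.explain(data.data[0])
print(explanation.explanation_data["feature_attributions"])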
@@ -0,0 +1,302 @@
# src/explainiverse/explainers/counterfactual/dice_wrapper.py
"""
Counterfactual Explainer - DiCE-style diverse counterfactual explanations.

Counterfactual explanations answer "What minimal changes would flip the prediction?"

Reference:
    Mothilal, R.K., Sharma, A., & Tan, C. (2020). Explaining Machine Learning
    Classifiers through Diverse Counterfactual Explanations. FAT* 2020.
"""

import numpy as np
from typing import List, Optional, Dict, Any, Union
from scipy.optimize import minimize
from explainiverse.core.explainer import BaseExplainer
from explainiverse.core.explanation import Explanation


class CounterfactualExplainer(BaseExplainer):
    """
    Counterfactual explainer using gradient-free optimization.

    Generates minimal perturbations that change the model's prediction
    to a desired class (or just a different class).

    Attributes:
        model: Model adapter with .predict() method
        training_data: Reference data for constraints
        feature_names: List of feature names
        continuous_features: List of continuous feature names
        categorical_features: List of categorical feature names
    """

    def __init__(
        self,
        model,
        training_data: np.ndarray,
        feature_names: List[str],
        continuous_features: Optional[List[str]] = None,
        categorical_features: Optional[List[str]] = None,
        feature_ranges: Optional[Dict[str, tuple]] = None,
        proximity_weight: float = 0.5,
        diversity_weight: float = 0.5,
        random_state: int = 42
    ):
        """
        Initialize the Counterfactual explainer.

        Args:
            model: Model adapter with .predict() method
            training_data: Reference data (n_samples, n_features)
            feature_names: List of feature names
            continuous_features: Features that can take continuous values
            categorical_features: Features with discrete values
            feature_ranges: Dict of {feature_name: (min, max)} constraints
            proximity_weight: Weight for proximity loss (closer to original)
            diversity_weight: Weight for diversity among counterfactuals
            random_state: Random seed
        """
        super().__init__(model)
        self.training_data = np.array(training_data)
        self.feature_names = list(feature_names)
        self.continuous_features = continuous_features or feature_names
        self.categorical_features = categorical_features or []
        self.proximity_weight = proximity_weight
        self.diversity_weight = diversity_weight
        self.random_state = random_state
        self.rng = np.random.RandomState(random_state)

        # Compute feature ranges from data if not provided
        if feature_ranges:
            self.feature_ranges = feature_ranges
        else:
            self.feature_ranges = {}
            for idx, name in enumerate(feature_names):
                values = self.training_data[:, idx]
                self.feature_ranges[name] = (float(np.min(values)), float(np.max(values)))

        # Compute feature scales for normalization
        self._compute_scales()

    def _compute_scales(self):
        """Compute scaling factors for each feature."""
        self.scales = np.zeros(len(self.feature_names))
        for idx, name in enumerate(self.feature_names):
            min_val, max_val = self.feature_ranges.get(name, (0, 1))
            scale = max_val - min_val
            self.scales[idx] = scale if scale > 0 else 1.0

    def _get_target_class(
        self,
        instance: np.ndarray,
        desired_class: Optional[int] = None
    ) -> int:
        """Determine the target class for the counterfactual."""
        predictions = self.model.predict(instance.reshape(1, -1))

        if predictions.ndim == 2:
            current_class = np.argmax(predictions[0])
            n_classes = predictions.shape[1]
        else:
            current_class = int(predictions[0] > 0.5)
            n_classes = 2

        if desired_class is not None:
            return desired_class

        # Default: flip to any other class
        if n_classes == 2:
            return 1 - current_class
        else:
            # For multi-class, pick the second most likely class
            probs = predictions[0]
            sorted_classes = np.argsort(probs)[::-1]
            return int(sorted_classes[1]) if sorted_classes[0] == current_class else int(sorted_classes[0])

    def _proximity_loss(self, cf: np.ndarray, original: np.ndarray) -> float:
        """Compute normalized distance between counterfactual and original."""
        diff = (cf - original) / self.scales
        return float(np.sum(diff ** 2))

    def _validity_loss(self, cf: np.ndarray, target_class: int) -> float:
        """Compute loss for achieving the target class."""
        predictions = self.model.predict(cf.reshape(1, -1))

        if predictions.ndim == 2:
            target_prob = predictions[0, target_class]
            return -np.log(target_prob + 1e-10)
        else:
            if target_class == 1:
                return -np.log(predictions[0] + 1e-10)
            else:
                return -np.log(1 - predictions[0] + 1e-10)

    def _diversity_loss(self, cfs: List[np.ndarray]) -> float:
        """Compute diversity loss (encourage different counterfactuals)."""
        if len(cfs) < 2:
            return 0.0

        total_dist = 0.0
        count = 0
        for i in range(len(cfs)):
            for j in range(i + 1, len(cfs)):
                diff = (cfs[i] - cfs[j]) / self.scales
                total_dist += np.sum(diff ** 2)
                count += 1

        return -total_dist / count if count > 0 else 0.0

    def _generate_single_counterfactual(
        self,
        instance: np.ndarray,
        target_class: int,
        max_iter: int = 100
    ) -> Optional[np.ndarray]:
        """
        Generate a single counterfactual using optimization.
        """
        # Start from a random perturbation of the instance
        cf = instance.copy()
        cf += self.rng.randn(len(cf)) * 0.1 * self.scales

        # Clip to valid ranges
        for idx, name in enumerate(self.feature_names):
            min_val, max_val = self.feature_ranges.get(name, (-np.inf, np.inf))
            cf[idx] = np.clip(cf[idx], min_val, max_val)

        def objective(x):
            validity = self._validity_loss(x, target_class)
            proximity = self._proximity_loss(x, instance)
            return validity + self.proximity_weight * proximity

        # Define bounds
        bounds = []
        for idx, name in enumerate(self.feature_names):
            min_val, max_val = self.feature_ranges.get(name, (-np.inf, np.inf))
            bounds.append((min_val, max_val))

        # Optimize
        result = minimize(
            objective,
            cf,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': max_iter}
        )

        cf_result = result.x

        # Check if valid (prediction changed)
        predictions = self.model.predict(cf_result.reshape(1, -1))
        if predictions.ndim == 2:
            pred_class = np.argmax(predictions[0])
        else:
            pred_class = int(predictions[0] > 0.5)

        if pred_class == target_class:
            return cf_result
        return None

    def _generate_diverse_counterfactuals(
        self,
        instance: np.ndarray,
        target_class: int,
        num_counterfactuals: int,
        max_attempts: int = 50
    ) -> List[np.ndarray]:
        """
        Generate multiple diverse counterfactuals.
        """
        counterfactuals = []
        attempts = 0

        while len(counterfactuals) < num_counterfactuals and attempts < max_attempts:
            # Add some randomization to encourage diversity
            self.rng = np.random.RandomState(self.random_state + attempts)

            cf = self._generate_single_counterfactual(instance, target_class)

            if cf is not None:
                # Check if it's diverse enough from existing CFs
                is_diverse = True
                for existing_cf in counterfactuals:
                    diff = np.abs(cf - existing_cf) / self.scales
                    if np.max(diff) < 0.1:  # Too similar
                        is_diverse = False
                        break

                if is_diverse:
                    counterfactuals.append(cf)

            attempts += 1

        return counterfactuals

    def explain(
        self,
        instance: np.ndarray,
        num_counterfactuals: int = 3,
        desired_class: Optional[int] = None,
        **kwargs
    ) -> Explanation:
        """
        Generate counterfactual explanations.

        Args:
            instance: The instance to explain (1D array)
            num_counterfactuals: Number of diverse counterfactuals to generate
            desired_class: Target class (default: flip to different class)

        Returns:
            Explanation object with counterfactuals and changes
        """
        instance = np.array(instance).flatten()
        target_class = self._get_target_class(instance, desired_class)

        # Get original prediction
        original_pred = self.model.predict(instance.reshape(1, -1))
        if original_pred.ndim == 2:
            original_class = int(np.argmax(original_pred[0]))
        else:
            original_class = int(original_pred[0] > 0.5)

        # Generate counterfactuals
        counterfactuals = self._generate_diverse_counterfactuals(
            instance, target_class, num_counterfactuals
        )

        # Compute changes for each counterfactual
        all_changes = []
        for cf in counterfactuals:
            changes = {}
            for idx, name in enumerate(self.feature_names):
                diff = cf[idx] - instance[idx]
                if abs(diff) > 1e-6:
                    changes[name] = {
                        "original": float(instance[idx]),
                        "counterfactual": float(cf[idx]),
                        "change": float(diff)
                    }
            all_changes.append(changes)

        # Compute feature importance based on average change magnitude
        feature_importance = {}
        for idx, name in enumerate(self.feature_names):
            total_change = 0.0
            for cf in counterfactuals:
                total_change += abs(cf[idx] - instance[idx]) / self.scales[idx]
            feature_importance[name] = total_change / max(len(counterfactuals), 1)

        return Explanation(
            explainer_name="Counterfactual",
            target_class=f"class_{target_class}",
            explanation_data={
                "counterfactuals": [cf.tolist() for cf in counterfactuals],
                "changes": all_changes,
                "original_class": original_class,
                "target_class": target_class,
                "num_generated": len(counterfactuals),
                "feature_attributions": feature_importance
            }
        )
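A minimal usage sketch of the new CounterfactualExplainer follows, again using a hypothetical ProbaAdapter stand-in (the package's real adapters may differ) and a toy binary classifier.

# Usage sketch (assumes explainiverse 0.2.0, scipy, and scikit-learn are installed).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from explainiverse.explainers.counterfactual.dice_wrapper import CounterfactualExplainer


class ProbaAdapter:
    """Minimal adapter exposing .predict() as class probabilities (hypothetical)."""

    def __init__(self, clf):
        self.clf = clf

    def predict(self, X):
        return self.clf.predict_proba(np.asarray(X))


X, y = make_classification(n_samples=500, n_features=5, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

explainer = CounterfactualExplainer(
    model=ProbaAdapter(clf),
    training_data=X,
    feature_names=[f"f{i}" for i in range(X.shape[1])],
)
explanation = explainer.explain(X[0], num_counterfactuals=2)
print(explanation.explanation_data["num_generated"])
print(explanation.explanation_data["changes"])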
@@ -0,0 +1,23 @@
# src/explainiverse/explainers/global_explainers/__init__.py
"""
Global explainers - model-level explanations.

These explainers provide insights about the overall model behavior,
not individual predictions.
"""

from explainiverse.explainers.global_explainers.permutation_importance import (
    PermutationImportanceExplainer
)
from explainiverse.explainers.global_explainers.partial_dependence import (
    PartialDependenceExplainer
)
from explainiverse.explainers.global_explainers.ale import ALEExplainer
from explainiverse.explainers.global_explainers.sage import SAGEExplainer

__all__ = [
    "PermutationImportanceExplainer",
    "PartialDependenceExplainer",
    "ALEExplainer",
    "SAGEExplainer",
]
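With this __init__ in place, the four global explainers it re-exports can be imported directly from the subpackage:

from explainiverse.explainers.global_explainers import (
    ALEExplainer,
    PartialDependenceExplainer,
    PermutationImportanceExplainer,
    SAGEExplainer,
)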
@@ -0,0 +1,191 @@
# src/explainiverse/explainers/global_explainers/ale.py
"""
Accumulated Local Effects (ALE) Explainer.

ALE plots are an alternative to Partial Dependence Plots that are unbiased
when features are correlated. They measure how the prediction changes locally
when the feature value changes.

Reference:
    Apley, D.W. & Zhu, J. (2020). Visualizing the Effects of Predictor Variables
    in Black Box Supervised Learning Models. Journal of the Royal Statistical Society
    Series B, 82(4), 1059-1086.
"""

import numpy as np
from typing import List, Optional, Union, Tuple
from explainiverse.core.explainer import BaseExplainer
from explainiverse.core.explanation import Explanation


class ALEExplainer(BaseExplainer):
    """
    Accumulated Local Effects (ALE) explainer.

    Unlike PDP, ALE avoids extrapolation issues when features are correlated
    by using local differences rather than marginal averages.

    Attributes:
        model: Model adapter with .predict() method
        X: Training/reference data
        feature_names: List of feature names
        n_bins: Number of bins for computing ALE
    """

    def __init__(
        self,
        model,
        X: np.ndarray,
        feature_names: List[str],
        n_bins: int = 20
    ):
        """
        Initialize the ALE explainer.

        Args:
            model: Model adapter with .predict() method
            X: Reference dataset (n_samples, n_features)
            feature_names: List of feature names
            n_bins: Number of bins for ALE computation
        """
        super().__init__(model)
        self.X = np.array(X)
        self.feature_names = list(feature_names)
        self.n_bins = n_bins

    def _get_feature_idx(self, feature: Union[int, str]) -> int:
        """Convert feature name to index if needed."""
        if isinstance(feature, str):
            return self.feature_names.index(feature)
        return feature

    def _compute_quantile_bins(self, values: np.ndarray) -> np.ndarray:
        """
        Compute bin edges using quantiles to ensure similar sample sizes per bin.
        """
        percentiles = np.linspace(0, 100, self.n_bins + 1)
        bin_edges = np.percentile(values, percentiles)
        # Remove duplicate edges
        bin_edges = np.unique(bin_edges)
        return bin_edges

    def _compute_ale_1d(
        self,
        feature_idx: int,
        target_class: int = 1
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute 1D ALE for a single feature.

        Args:
            feature_idx: Index of the feature
            target_class: Class index for which to compute ALE

        Returns:
            Tuple of (bin_centers, ale_values, bin_edges)
        """
        values = self.X[:, feature_idx]
        bin_edges = self._compute_quantile_bins(values)

        if len(bin_edges) < 2:
            # Not enough unique values
            return np.array([np.mean(values)]), np.array([0.0]), bin_edges

        # Compute local effects for each bin
        local_effects = []

        for i in range(len(bin_edges) - 1):
            lower, upper = bin_edges[i], bin_edges[i + 1]

            # Find samples in this bin
            if i == len(bin_edges) - 2:
                # Include upper bound in last bin
                in_bin = (values >= lower) & (values <= upper)
            else:
                in_bin = (values >= lower) & (values < upper)

            if not np.any(in_bin):
                local_effects.append(0.0)
                continue

            X_bin = self.X[in_bin]

            # Compute predictions at bin edges
            X_lower = X_bin.copy()
            X_lower[:, feature_idx] = lower

            X_upper = X_bin.copy()
            X_upper[:, feature_idx] = upper

            pred_lower = self.model.predict(X_lower)
            pred_upper = self.model.predict(X_upper)

            # Extract target class predictions
            if pred_lower.ndim == 2:
                pred_lower = pred_lower[:, target_class]
                pred_upper = pred_upper[:, target_class]

            # Local effect = average difference
            effect = np.mean(pred_upper - pred_lower)
            local_effects.append(effect)

        # Accumulate effects
        ale_values = np.cumsum(local_effects)

        # Center around zero (mean-center)
        ale_values = ale_values - np.mean(ale_values)

        # Compute bin centers
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        return bin_centers, ale_values, bin_edges

    def explain(
        self,
        feature: Union[int, str],
        target_class: int = 1,
        **kwargs
    ) -> Explanation:
        """
        Compute ALE for a specified feature.

        Args:
            feature: Feature index or name
            target_class: Class index for which to compute ALE

        Returns:
            Explanation object with ALE values
        """
        idx = self._get_feature_idx(feature)
        bin_centers, ale_values, bin_edges = self._compute_ale_1d(idx, target_class)

        feature_name = self.feature_names[idx]

        return Explanation(
            explainer_name="ALE",
            target_class=f"class_{target_class}",
            explanation_data={
                "ale_values": ale_values.tolist(),
                "bin_centers": bin_centers.tolist(),
                "bin_edges": bin_edges.tolist(),
                "feature": feature_name,
                "feature_attributions": {
                    feature_name: float(np.max(ale_values) - np.min(ale_values))
                }
            }
        )

    def explain_all(self, target_class: int = 1) -> List[Explanation]:
        """
        Compute ALE for all features.

        Args:
            target_class: Class index for which to compute ALE

        Returns:
            List of Explanation objects, one per feature
        """
        return [
            self.explain(idx, target_class)
            for idx in range(len(self.feature_names))
        ]
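A minimal usage sketch of ALEExplainer follows. As in the earlier sketches, ProbaAdapter is a hypothetical stand-in adapter; the feature name and dataset are only illustrative.

# Usage sketch (assumes explainiverse 0.2.0 and scikit-learn are installed).
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier

from explainiverse.explainers.global_explainers import ALEExplainer


class ProbaAdapter:
    """Minimal adapter exposing .predict() as class probabilities (hypothetical)."""

    def __init__(self, clf):
        self.clf = clf

    def predict(self, X):
        return self.clf.predict_proba(np.asarray(X))


data = load_breast_cancer()
clf = GradientBoostingClassifier(random_state=0).fit(data.data, data.target)

ale = ALEExplainer(
    model=ProbaAdapter(clf),
    X=data.data,
    feature_names=list(data.feature_names),
    n_bins=10,
)
explanation = ale.explain("mean radius", target_class=1)
print(explanation.explanation_data["ale_values"][:5])
print(explanation.explanation_data["feature_attributions"])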