explainiverse 0.1.1a1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. explainiverse/__init__.py +45 -1
  2. explainiverse/adapters/__init__.py +9 -0
  3. explainiverse/adapters/base_adapter.py +25 -25
  4. explainiverse/adapters/sklearn_adapter.py +32 -32
  5. explainiverse/core/__init__.py +22 -0
  6. explainiverse/core/explainer.py +31 -31
  7. explainiverse/core/explanation.py +24 -24
  8. explainiverse/core/registry.py +545 -0
  9. explainiverse/engine/__init__.py +8 -0
  10. explainiverse/engine/suite.py +142 -142
  11. explainiverse/evaluation/__init__.py +8 -0
  12. explainiverse/evaluation/metrics.py +232 -232
  13. explainiverse/explainers/__init__.py +38 -0
  14. explainiverse/explainers/attribution/__init__.py +9 -0
  15. explainiverse/explainers/attribution/lime_wrapper.py +90 -63
  16. explainiverse/explainers/attribution/shap_wrapper.py +89 -66
  17. explainiverse/explainers/counterfactual/__init__.py +8 -0
  18. explainiverse/explainers/counterfactual/dice_wrapper.py +302 -0
  19. explainiverse/explainers/global_explainers/__init__.py +23 -0
  20. explainiverse/explainers/global_explainers/ale.py +191 -0
  21. explainiverse/explainers/global_explainers/partial_dependence.py +192 -0
  22. explainiverse/explainers/global_explainers/permutation_importance.py +123 -0
  23. explainiverse/explainers/global_explainers/sage.py +164 -0
  24. explainiverse/explainers/rule_based/__init__.py +8 -0
  25. explainiverse/explainers/rule_based/anchors_wrapper.py +350 -0
  26. explainiverse-0.2.0.dist-info/METADATA +264 -0
  27. explainiverse-0.2.0.dist-info/RECORD +29 -0
  28. explainiverse-0.1.1a1.dist-info/METADATA +0 -128
  29. explainiverse-0.1.1a1.dist-info/RECORD +0 -19
  30. {explainiverse-0.1.1a1.dist-info → explainiverse-0.2.0.dist-info}/LICENSE +0 -0
  31. {explainiverse-0.1.1a1.dist-info → explainiverse-0.2.0.dist-info}/WHEEL +0 -0
@@ -1,66 +1,89 @@
- # src/explainiverse/explainers/attribution/shap_wrapper.py
-
- import shap
- import numpy as np
-
- from explainiverse.core.explainer import BaseExplainer
- from explainiverse.core.explanation import Explanation
-
-
- class ShapExplainer(BaseExplainer):
-     """
-     SHAP explainer (KernelSHAP-based) for model-agnostic explanations.
-     """
-
-     def __init__(self, model, background_data, feature_names, class_names):
-         """
-         Args:
-             model: A model adapter with a .predict method.
-             background_data: A 2D numpy array used as SHAP background distribution.
-             feature_names: List of feature names.
-             class_names: List of class labels.
-         """
-         super().__init__(model)
-         self.feature_names = feature_names
-         self.class_names = class_names
-         self.explainer = shap.KernelExplainer(model.predict, background_data)
-
-
-     def explain(self, instance, top_labels=1):
-         """
-         Generate SHAP explanation for a single instance.
-
-         Args:
-             instance: 1D numpy array of input features.
-             top_labels: Number of top classes to explain (default: 1)
-
-         Returns:
-             Explanation object
-         """
-         instance = np.array(instance).reshape(1, -1)  # Ensure 2D
-         shap_values = self.explainer.shap_values(instance)
-
-         if isinstance(shap_values, list):
-             # Multi-class: list of arrays, one per class
-             predicted_probs = self.model.predict(instance)[0]
-             top_indices = np.argsort(predicted_probs)[-top_labels:][::-1]
-             label_index = top_indices[0]
-             label_name = self.class_names[label_index]
-             class_shap = shap_values[label_index][0]
-         else:
-             # Single-class (regression or binary classification)
-             label_index = 0
-             label_name = self.class_names[0] if self.class_names else "class_0"
-             class_shap = shap_values[0]
-
-         flat_shap = np.array(class_shap).flatten()
-         attributions = {
-             fname: float(flat_shap[i])
-             for i, fname in enumerate(self.feature_names)
-         }
-
-         return Explanation(
-             explainer_name="SHAP",
-             target_class=label_name,
-             explanation_data={"feature_attributions": attributions}
-         )
+ # src/explainiverse/explainers/attribution/shap_wrapper.py
+ """
+ SHAP Explainer - SHapley Additive exPlanations.
+
+ SHAP values provide a unified measure of feature importance based on
+ game-theoretic Shapley values, offering both local and global interpretability.
+
+ Reference:
+     Lundberg, S.M. & Lee, S.I. (2017). A Unified Approach to Interpreting
+     Model Predictions. NeurIPS 2017.
+ """
+
+ import shap
+ import numpy as np
+
+ from explainiverse.core.explainer import BaseExplainer
+ from explainiverse.core.explanation import Explanation
+
+
+ class ShapExplainer(BaseExplainer):
+     """
+     SHAP explainer (KernelSHAP-based) for model-agnostic explanations.
+
+     KernelSHAP is a model-agnostic method that approximates SHAP values
+     using a weighted linear regression. It works with any model that
+     provides predictions.
+
+     Attributes:
+         model: Model adapter with .predict() method
+         feature_names: List of feature names
+         class_names: List of class labels
+         explainer: The underlying SHAP KernelExplainer
+     """
+
+     def __init__(self, model, background_data, feature_names, class_names):
+         """
+         Initialize the SHAP explainer.
+
+         Args:
+             model: A model adapter with a .predict method.
+             background_data: A 2D numpy array used as SHAP background distribution.
+                 Typically a representative sample of training data.
+             feature_names: List of feature names.
+             class_names: List of class labels.
+         """
+         super().__init__(model)
+         self.feature_names = list(feature_names)
+         self.class_names = list(class_names)
+         self.explainer = shap.KernelExplainer(model.predict, background_data)
+
+     def explain(self, instance, top_labels=1):
+         """
+         Generate SHAP explanation for a single instance.
+
+         Args:
+             instance: 1D numpy array of input features.
+             top_labels: Number of top classes to explain (default: 1)
+
+         Returns:
+             Explanation object with feature attributions
+         """
+         instance = np.array(instance).reshape(1, -1)  # Ensure 2D
+         shap_values = self.explainer.shap_values(instance)
+
+         if isinstance(shap_values, list):
+             # Multi-class: list of arrays, one per class
+             predicted_probs = self.model.predict(instance)[0]
+             top_indices = np.argsort(predicted_probs)[-top_labels:][::-1]
+             label_index = top_indices[0]
+             label_name = self.class_names[label_index]
+             class_shap = shap_values[label_index][0]
+         else:
+             # Single-class (regression or binary classification)
+             label_index = 0
+             label_name = self.class_names[0] if self.class_names else "class_0"
+             class_shap = shap_values[0]
+
+         # Build attributions dict
+         flat_shap = np.array(class_shap).flatten()
+         attributions = {
+             fname: float(flat_shap[i])
+             for i, fname in enumerate(self.feature_names)
+         }
+
+         return Explanation(
+             explainer_name="SHAP",
+             target_class=label_name,
+             explanation_data={"feature_attributions": attributions}
+         )
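
Usage sketch for the rewritten ShapExplainer above. It assumes only the contract stated in the docstrings (a model adapter exposing .predict() that returns class probabilities); the PredictAdapter class and the iris setup are illustrative stand-ins, not part of explainiverse.

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer


class PredictAdapter:
    """Hypothetical duck-typed adapter: ShapExplainer only needs a
    .predict() method returning per-class probabilities."""

    def __init__(self, estimator):
        self.estimator = estimator

    def predict(self, X):
        return self.estimator.predict_proba(X)


data = load_iris()
clf = RandomForestClassifier(random_state=0).fit(data.data, data.target)

explainer = ShapExplainer(
    model=PredictAdapter(clf),
    background_data=data.data[:50],          # representative background sample
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),
)
explanation = explainer.explain(data.data[0], top_labels=1)

Because .predict() returns per-class probabilities here, explain() takes the multi-class branch and reports per-feature SHAP attributions for the top predicted class.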
@@ -0,0 +1,8 @@
+ # src/explainiverse/explainers/counterfactual/__init__.py
+ """
+ Counterfactual explainers - "what-if" explanations.
+ """
+
+ from explainiverse.explainers.counterfactual.dice_wrapper import CounterfactualExplainer
+
+ __all__ = ["CounterfactualExplainer"]
@@ -0,0 +1,302 @@
+ # src/explainiverse/explainers/counterfactual/dice_wrapper.py
+ """
+ Counterfactual Explainer - DiCE-style diverse counterfactual explanations.
+
+ Counterfactual explanations answer "What minimal changes would flip the prediction?"
+
+ Reference:
+     Mothilal, R.K., Sharma, A., & Tan, C. (2020). Explaining Machine Learning
+     Classifiers through Diverse Counterfactual Explanations. FAT* 2020.
+ """
+
+ import numpy as np
+ from typing import List, Optional, Dict, Any, Union
+ from scipy.optimize import minimize
+ from explainiverse.core.explainer import BaseExplainer
+ from explainiverse.core.explanation import Explanation
+
+
+ class CounterfactualExplainer(BaseExplainer):
+     """
+     Counterfactual explainer using gradient-free optimization.
+
+     Generates minimal perturbations that change the model's prediction
+     to a desired class (or just a different class).
+
+     Attributes:
+         model: Model adapter with .predict() method
+         training_data: Reference data for constraints
+         feature_names: List of feature names
+         continuous_features: List of continuous feature names
+         categorical_features: List of categorical feature names
+     """
+
+     def __init__(
+         self,
+         model,
+         training_data: np.ndarray,
+         feature_names: List[str],
+         continuous_features: Optional[List[str]] = None,
+         categorical_features: Optional[List[str]] = None,
+         feature_ranges: Optional[Dict[str, tuple]] = None,
+         proximity_weight: float = 0.5,
+         diversity_weight: float = 0.5,
+         random_state: int = 42
+     ):
+         """
+         Initialize the Counterfactual explainer.
+
+         Args:
+             model: Model adapter with .predict() method
+             training_data: Reference data (n_samples, n_features)
+             feature_names: List of feature names
+             continuous_features: Features that can take continuous values
+             categorical_features: Features with discrete values
+             feature_ranges: Dict of {feature_name: (min, max)} constraints
+             proximity_weight: Weight for proximity loss (closer to original)
+             diversity_weight: Weight for diversity among counterfactuals
+             random_state: Random seed
+         """
+         super().__init__(model)
+         self.training_data = np.array(training_data)
+         self.feature_names = list(feature_names)
+         self.continuous_features = continuous_features or feature_names
+         self.categorical_features = categorical_features or []
+         self.proximity_weight = proximity_weight
+         self.diversity_weight = diversity_weight
+         self.random_state = random_state
+         self.rng = np.random.RandomState(random_state)
+
+         # Compute feature ranges from data if not provided
+         if feature_ranges:
+             self.feature_ranges = feature_ranges
+         else:
+             self.feature_ranges = {}
+             for idx, name in enumerate(feature_names):
+                 values = self.training_data[:, idx]
+                 self.feature_ranges[name] = (float(np.min(values)), float(np.max(values)))
+
+         # Compute feature scales for normalization
+         self._compute_scales()
+
+     def _compute_scales(self):
+         """Compute scaling factors for each feature."""
+         self.scales = np.zeros(len(self.feature_names))
+         for idx, name in enumerate(self.feature_names):
+             min_val, max_val = self.feature_ranges.get(name, (0, 1))
+             scale = max_val - min_val
+             self.scales[idx] = scale if scale > 0 else 1.0
+
+     def _get_target_class(
+         self,
+         instance: np.ndarray,
+         desired_class: Optional[int] = None
+     ) -> int:
+         """Determine the target class for the counterfactual."""
+         predictions = self.model.predict(instance.reshape(1, -1))
+
+         if predictions.ndim == 2:
+             current_class = np.argmax(predictions[0])
+             n_classes = predictions.shape[1]
+         else:
+             current_class = int(predictions[0] > 0.5)
+             n_classes = 2
+
+         if desired_class is not None:
+             return desired_class
+
+         # Default: flip to any other class
+         if n_classes == 2:
+             return 1 - current_class
+         else:
+             # For multi-class, pick the second most likely class
+             probs = predictions[0]
+             sorted_classes = np.argsort(probs)[::-1]
+             return int(sorted_classes[1]) if sorted_classes[0] == current_class else int(sorted_classes[0])
+
+     def _proximity_loss(self, cf: np.ndarray, original: np.ndarray) -> float:
+         """Compute normalized distance between counterfactual and original."""
+         diff = (cf - original) / self.scales
+         return float(np.sum(diff ** 2))
+
+     def _validity_loss(self, cf: np.ndarray, target_class: int) -> float:
+         """Compute loss for achieving the target class."""
+         predictions = self.model.predict(cf.reshape(1, -1))
+
+         if predictions.ndim == 2:
+             target_prob = predictions[0, target_class]
+             return -np.log(target_prob + 1e-10)
+         else:
+             if target_class == 1:
+                 return -np.log(predictions[0] + 1e-10)
+             else:
+                 return -np.log(1 - predictions[0] + 1e-10)
+
+     def _diversity_loss(self, cfs: List[np.ndarray]) -> float:
+         """Compute diversity loss (encourage different counterfactuals)."""
+         if len(cfs) < 2:
+             return 0.0
+
+         total_dist = 0.0
+         count = 0
+         for i in range(len(cfs)):
+             for j in range(i + 1, len(cfs)):
+                 diff = (cfs[i] - cfs[j]) / self.scales
+                 total_dist += np.sum(diff ** 2)
+                 count += 1
+
+         return -total_dist / count if count > 0 else 0.0
+
+     def _generate_single_counterfactual(
+         self,
+         instance: np.ndarray,
+         target_class: int,
+         max_iter: int = 100
+     ) -> Optional[np.ndarray]:
+         """
+         Generate a single counterfactual using optimization.
+         """
+         # Start from a random perturbation of the instance
+         cf = instance.copy()
+         cf += self.rng.randn(len(cf)) * 0.1 * self.scales
+
+         # Clip to valid ranges
+         for idx, name in enumerate(self.feature_names):
+             min_val, max_val = self.feature_ranges.get(name, (-np.inf, np.inf))
+             cf[idx] = np.clip(cf[idx], min_val, max_val)
+
+         def objective(x):
+             validity = self._validity_loss(x, target_class)
+             proximity = self._proximity_loss(x, instance)
+             return validity + self.proximity_weight * proximity
+
+         # Define bounds
+         bounds = []
+         for idx, name in enumerate(self.feature_names):
+             min_val, max_val = self.feature_ranges.get(name, (-np.inf, np.inf))
+             bounds.append((min_val, max_val))
+
+         # Optimize
+         result = minimize(
+             objective,
+             cf,
+             method='L-BFGS-B',
+             bounds=bounds,
+             options={'maxiter': max_iter}
+         )
+
+         cf_result = result.x
+
+         # Check if valid (prediction changed)
+         predictions = self.model.predict(cf_result.reshape(1, -1))
+         if predictions.ndim == 2:
+             pred_class = np.argmax(predictions[0])
+         else:
+             pred_class = int(predictions[0] > 0.5)
+
+         if pred_class == target_class:
+             return cf_result
+         return None
+
+     def _generate_diverse_counterfactuals(
+         self,
+         instance: np.ndarray,
+         target_class: int,
+         num_counterfactuals: int,
+         max_attempts: int = 50
+     ) -> List[np.ndarray]:
+         """
+         Generate multiple diverse counterfactuals.
+         """
+         counterfactuals = []
+         attempts = 0
+
+         while len(counterfactuals) < num_counterfactuals and attempts < max_attempts:
+             # Add some randomization to encourage diversity
+             self.rng = np.random.RandomState(self.random_state + attempts)
+
+             cf = self._generate_single_counterfactual(instance, target_class)
+
+             if cf is not None:
+                 # Check if it's diverse enough from existing CFs
+                 is_diverse = True
+                 for existing_cf in counterfactuals:
+                     diff = np.abs(cf - existing_cf) / self.scales
+                     if np.max(diff) < 0.1:  # Too similar
+                         is_diverse = False
+                         break
+
+                 if is_diverse:
+                     counterfactuals.append(cf)
+
+             attempts += 1
+
+         return counterfactuals
+
+     def explain(
+         self,
+         instance: np.ndarray,
+         num_counterfactuals: int = 3,
+         desired_class: Optional[int] = None,
+         **kwargs
+     ) -> Explanation:
+         """
+         Generate counterfactual explanations.
+
+         Args:
+             instance: The instance to explain (1D array)
+             num_counterfactuals: Number of diverse counterfactuals to generate
+             desired_class: Target class (default: flip to different class)
+
+         Returns:
+             Explanation object with counterfactuals and changes
+         """
+         instance = np.array(instance).flatten()
+         target_class = self._get_target_class(instance, desired_class)
+
+         # Get original prediction
+         original_pred = self.model.predict(instance.reshape(1, -1))
+         if original_pred.ndim == 2:
+             original_class = int(np.argmax(original_pred[0]))
+         else:
+             original_class = int(original_pred[0] > 0.5)
+
+         # Generate counterfactuals
+         counterfactuals = self._generate_diverse_counterfactuals(
+             instance, target_class, num_counterfactuals
+         )
+
+         # Compute changes for each counterfactual
+         all_changes = []
+         for cf in counterfactuals:
+             changes = {}
+             for idx, name in enumerate(self.feature_names):
+                 diff = cf[idx] - instance[idx]
+                 if abs(diff) > 1e-6:
+                     changes[name] = {
+                         "original": float(instance[idx]),
+                         "counterfactual": float(cf[idx]),
+                         "change": float(diff)
+                     }
+             all_changes.append(changes)
+
+         # Compute feature importance based on average change magnitude
+         feature_importance = {}
+         for idx, name in enumerate(self.feature_names):
+             total_change = 0.0
+             for cf in counterfactuals:
+                 total_change += abs(cf[idx] - instance[idx]) / self.scales[idx]
+             feature_importance[name] = total_change / max(len(counterfactuals), 1)
+
+         return Explanation(
+             explainer_name="Counterfactual",
+             target_class=f"class_{target_class}",
+             explanation_data={
+                 "counterfactuals": [cf.tolist() for cf in counterfactuals],
+                 "changes": all_changes,
+                 "original_class": original_class,
+                 "target_class": target_class,
+                 "num_generated": len(counterfactuals),
+                 "feature_attributions": feature_importance
+             }
+         )
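
A comparable sketch for the new CounterfactualExplainer. Only the documented .predict() contract is assumed; ProbaAdapter and the synthetic dataset are illustrative, not part of the package.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from explainiverse.explainers.counterfactual.dice_wrapper import CounterfactualExplainer


class ProbaAdapter:
    """Hypothetical wrapper satisfying the documented .predict() contract."""

    def __init__(self, estimator):
        self.estimator = estimator

    def predict(self, X):
        return self.estimator.predict_proba(X)


X, y = make_classification(n_samples=200, n_features=4, random_state=0)
clf = LogisticRegression().fit(X, y)

explainer = CounterfactualExplainer(
    model=ProbaAdapter(clf),
    training_data=X,
    feature_names=[f"f{i}" for i in range(X.shape[1])],
    proximity_weight=0.5,    # trade off closeness to the original instance
)
explanation = explainer.explain(X[0], num_counterfactuals=3)

When feature_ranges is omitted, per-feature (min, max) bounds are inferred from training_data, and the L-BFGS-B search is constrained to stay inside them.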
@@ -0,0 +1,23 @@
+ # src/explainiverse/explainers/global_explainers/__init__.py
+ """
+ Global explainers - model-level explanations.
+
+ These explainers provide insights about the overall model behavior,
+ not individual predictions.
+ """
+
+ from explainiverse.explainers.global_explainers.permutation_importance import (
+     PermutationImportanceExplainer
+ )
+ from explainiverse.explainers.global_explainers.partial_dependence import (
+     PartialDependenceExplainer
+ )
+ from explainiverse.explainers.global_explainers.ale import ALEExplainer
+ from explainiverse.explainers.global_explainers.sage import SAGEExplainer
+
+ __all__ = [
+     "PermutationImportanceExplainer",
+     "PartialDependenceExplainer",
+     "ALEExplainer",
+     "SAGEExplainer",
+ ]
@@ -0,0 +1,191 @@
+ # src/explainiverse/explainers/global_explainers/ale.py
+ """
+ Accumulated Local Effects (ALE) Explainer.
+
+ ALE plots are an alternative to Partial Dependence Plots that are unbiased
+ when features are correlated. They measure how the prediction changes locally
+ when the feature value changes.
+
+ Reference:
+     Apley, D.W. & Zhu, J. (2020). Visualizing the Effects of Predictor Variables
+     in Black Box Supervised Learning Models. Journal of the Royal Statistical Society
+     Series B, 82(4), 1059-1086.
+ """
+
+ import numpy as np
+ from typing import List, Optional, Union, Tuple
+ from explainiverse.core.explainer import BaseExplainer
+ from explainiverse.core.explanation import Explanation
+
+
+ class ALEExplainer(BaseExplainer):
+     """
+     Accumulated Local Effects (ALE) explainer.
+
+     Unlike PDP, ALE avoids extrapolation issues when features are correlated
+     by using local differences rather than marginal averages.
+
+     Attributes:
+         model: Model adapter with .predict() method
+         X: Training/reference data
+         feature_names: List of feature names
+         n_bins: Number of bins for computing ALE
+     """
+
+     def __init__(
+         self,
+         model,
+         X: np.ndarray,
+         feature_names: List[str],
+         n_bins: int = 20
+     ):
+         """
+         Initialize the ALE explainer.
+
+         Args:
+             model: Model adapter with .predict() method
+             X: Reference dataset (n_samples, n_features)
+             feature_names: List of feature names
+             n_bins: Number of bins for ALE computation
+         """
+         super().__init__(model)
+         self.X = np.array(X)
+         self.feature_names = list(feature_names)
+         self.n_bins = n_bins
+
+     def _get_feature_idx(self, feature: Union[int, str]) -> int:
+         """Convert feature name to index if needed."""
+         if isinstance(feature, str):
+             return self.feature_names.index(feature)
+         return feature
+
+     def _compute_quantile_bins(self, values: np.ndarray) -> np.ndarray:
+         """
+         Compute bin edges using quantiles to ensure similar sample sizes per bin.
+         """
+         percentiles = np.linspace(0, 100, self.n_bins + 1)
+         bin_edges = np.percentile(values, percentiles)
+         # Remove duplicate edges
+         bin_edges = np.unique(bin_edges)
+         return bin_edges
+
+     def _compute_ale_1d(
+         self,
+         feature_idx: int,
+         target_class: int = 1
+     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         """
+         Compute 1D ALE for a single feature.
+
+         Args:
+             feature_idx: Index of the feature
+             target_class: Class index for which to compute ALE
+
+         Returns:
+             Tuple of (bin_centers, ale_values, bin_edges)
+         """
+         values = self.X[:, feature_idx]
+         bin_edges = self._compute_quantile_bins(values)
+
+         if len(bin_edges) < 2:
+             # Not enough unique values
+             return np.array([np.mean(values)]), np.array([0.0]), bin_edges
+
+         # Compute local effects for each bin
+         local_effects = []
+
+         for i in range(len(bin_edges) - 1):
+             lower, upper = bin_edges[i], bin_edges[i + 1]
+
+             # Find samples in this bin
+             if i == len(bin_edges) - 2:
+                 # Include upper bound in last bin
+                 in_bin = (values >= lower) & (values <= upper)
+             else:
+                 in_bin = (values >= lower) & (values < upper)
+
+             if not np.any(in_bin):
+                 local_effects.append(0.0)
+                 continue
+
+             X_bin = self.X[in_bin]
+
+             # Compute predictions at bin edges
+             X_lower = X_bin.copy()
+             X_lower[:, feature_idx] = lower
+
+             X_upper = X_bin.copy()
+             X_upper[:, feature_idx] = upper
+
+             pred_lower = self.model.predict(X_lower)
+             pred_upper = self.model.predict(X_upper)
+
+             # Extract target class predictions
+             if pred_lower.ndim == 2:
+                 pred_lower = pred_lower[:, target_class]
+                 pred_upper = pred_upper[:, target_class]
+
+             # Local effect = average difference
+             effect = np.mean(pred_upper - pred_lower)
+             local_effects.append(effect)
+
+         # Accumulate effects
+         ale_values = np.cumsum(local_effects)
+
+         # Center around zero (mean-center)
+         ale_values = ale_values - np.mean(ale_values)
+
+         # Compute bin centers
+         bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
+
+         return bin_centers, ale_values, bin_edges
+
+     def explain(
+         self,
+         feature: Union[int, str],
+         target_class: int = 1,
+         **kwargs
+     ) -> Explanation:
+         """
+         Compute ALE for a specified feature.
+
+         Args:
+             feature: Feature index or name
+             target_class: Class index for which to compute ALE
+
+         Returns:
+             Explanation object with ALE values
+         """
+         idx = self._get_feature_idx(feature)
+         bin_centers, ale_values, bin_edges = self._compute_ale_1d(idx, target_class)
+
+         feature_name = self.feature_names[idx]
+
+         return Explanation(
+             explainer_name="ALE",
+             target_class=f"class_{target_class}",
+             explanation_data={
+                 "ale_values": ale_values.tolist(),
+                 "bin_centers": bin_centers.tolist(),
+                 "bin_edges": bin_edges.tolist(),
+                 "feature": feature_name,
+                 "feature_attributions": {
+                     feature_name: float(np.max(ale_values) - np.min(ale_values))
+                 }
+             }
+         )
+
+     def explain_all(self, target_class: int = 1) -> List[Explanation]:
+         """
+         Compute ALE for all features.
+
+         Args:
+             target_class: Class index for which to compute ALE
+
+         Returns:
+             List of Explanation objects, one per feature
+         """
+         return [
+             self.explain(idx, target_class)
+             for idx in range(len(self.feature_names))
+         ]
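
Finally, a self-contained sketch for ALEExplainer under the same assumed adapter pattern (again, ProbaAdapter and the synthetic data are illustrative, not part of the package).

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from explainiverse.explainers.global_explainers.ale import ALEExplainer


class ProbaAdapter:
    """Hypothetical wrapper exposing the documented .predict() contract."""

    def __init__(self, estimator):
        self.estimator = estimator

    def predict(self, X):
        return self.estimator.predict_proba(X)


X, y = make_classification(n_samples=300, n_features=4, random_state=0)
clf = LogisticRegression().fit(X, y)
names = [f"f{i}" for i in range(X.shape[1])]

ale = ALEExplainer(model=ProbaAdapter(clf), X=X, feature_names=names, n_bins=10)
single = ale.explain("f0", target_class=1)     # ALE curve for one feature, by name
curves = ale.explain_all(target_class=1)       # one Explanation per feature

Each Explanation records the ALE curve (bin_centers, ale_values, bin_edges) and summarizes the feature's effect range as its attribution score, matching the explanation_data keys above.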