explainiverse 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/__init__.py +15 -3
- explainiverse/adapters/__init__.py +11 -1
- explainiverse/adapters/pytorch_adapter.py +396 -0
- explainiverse/core/registry.py +18 -0
- explainiverse/explainers/__init__.py +4 -1
- explainiverse/explainers/attribution/__init__.py +2 -1
- explainiverse/explainers/attribution/treeshap_wrapper.py +434 -0
- {explainiverse-0.2.0.dist-info → explainiverse-0.2.2.dist-info}/METADATA +79 -10
- {explainiverse-0.2.0.dist-info → explainiverse-0.2.2.dist-info}/RECORD +11 -9
- {explainiverse-0.2.0.dist-info → explainiverse-0.2.2.dist-info}/LICENSE +0 -0
- {explainiverse-0.2.0.dist-info → explainiverse-0.2.2.dist-info}/WHEEL +0 -0
explainiverse/__init__.py
CHANGED
@@ -2,8 +2,9 @@
 """
 Explainiverse - A unified, extensible explainability framework.
 
-Supports multiple XAI methods including LIME, SHAP, Anchors,
-Permutation Importance, PDP, ALE, and SAGE through a
+Supports multiple XAI methods including LIME, SHAP, TreeSHAP, Anchors,
+Counterfactuals, Permutation Importance, PDP, ALE, and SAGE through a
+consistent interface.
 
 Quick Start:
     from explainiverse import default_registry
@@ -14,6 +15,10 @@ Quick Start:
     # Create an explainer
     explainer = default_registry.create("lime", model=adapter, training_data=X, ...)
     explanation = explainer.explain(instance)
+
+For PyTorch models:
+    from explainiverse import PyTorchAdapter  # Requires torch
+    adapter = PyTorchAdapter(model, task="classification")
 """
 
 from explainiverse.core.explainer import BaseExplainer
@@ -25,9 +30,10 @@ from explainiverse.core.registry import (
     get_default_registry,
 )
 from explainiverse.adapters.sklearn_adapter import SklearnAdapter
+from explainiverse.adapters import TORCH_AVAILABLE
 from explainiverse.engine.suite import ExplanationSuite
 
-__version__ = "0.2.0"
+__version__ = "0.2.2"
 
 __all__ = [
     # Core
@@ -40,6 +46,12 @@ __all__ = [
     "get_default_registry",
     # Adapters
     "SklearnAdapter",
+    "TORCH_AVAILABLE",
     # Engine
     "ExplanationSuite",
 ]
+
+# Conditionally export PyTorchAdapter if torch is available
+if TORCH_AVAILABLE:
+    from explainiverse.adapters import PyTorchAdapter
+    __all__.append("PyTorchAdapter")
explainiverse/adapters/__init__.py
CHANGED

@@ -1,9 +1,19 @@
 # src/explainiverse/adapters/__init__.py
 """
 Model adapters - wrappers that provide a consistent interface for different ML frameworks.
+
+Available adapters:
+- SklearnAdapter: For scikit-learn models (always available)
+- PyTorchAdapter: For PyTorch nn.Module models (requires torch)
 """
 
 from explainiverse.adapters.base_adapter import BaseModelAdapter
 from explainiverse.adapters.sklearn_adapter import SklearnAdapter
 
-
+# Conditionally import PyTorchAdapter if torch is available
+try:
+    from explainiverse.adapters.pytorch_adapter import PyTorchAdapter, TORCH_AVAILABLE
+    __all__ = ["BaseModelAdapter", "SklearnAdapter", "PyTorchAdapter", "TORCH_AVAILABLE"]
+except ImportError:
+    TORCH_AVAILABLE = False
+    __all__ = ["BaseModelAdapter", "SklearnAdapter", "TORCH_AVAILABLE"]
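The try/except above makes torch support purely optional. A minimal sketch (illustrative, not part of the package) of how downstream code might feature-detect on the exported flag; the `wrap_model` helper name is hypothetical:

```python
# Hypothetical helper (not in explainiverse): choose an adapter based on
# whether torch is installed and what kind of model was passed in.
from explainiverse.adapters import TORCH_AVAILABLE, SklearnAdapter

def wrap_model(model, **kwargs):
    if TORCH_AVAILABLE:
        import torch.nn as nn
        from explainiverse.adapters import PyTorchAdapter
        if isinstance(model, nn.Module):
            return PyTorchAdapter(model, **kwargs)
    # Anything else is treated as a scikit-learn style estimator
    return SklearnAdapter(model, **kwargs)
```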
explainiverse/adapters/pytorch_adapter.py
ADDED

@@ -0,0 +1,396 @@
+# src/explainiverse/adapters/pytorch_adapter.py
+"""
+PyTorch Model Adapter for Explainiverse.
+
+Provides a unified interface for PyTorch neural networks, enabling
+compatibility with all explainers in the framework.
+
+Example:
+    import torch.nn as nn
+    from explainiverse.adapters import PyTorchAdapter
+
+    model = nn.Sequential(
+        nn.Linear(10, 64),
+        nn.ReLU(),
+        nn.Linear(64, 3)
+    )
+
+    adapter = PyTorchAdapter(
+        model,
+        task="classification",
+        class_names=["cat", "dog", "bird"]
+    )
+
+    probs = adapter.predict(X)  # Returns numpy array
+"""
+
+import numpy as np
+from typing import List, Optional, Union, Callable
+
+from .base_adapter import BaseModelAdapter
+
+# Check if PyTorch is available
+try:
+    import torch
+    import torch.nn as nn
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+    torch = None
+    nn = None
+
+
+def _check_torch_available():
+    """Raise ImportError if PyTorch is not installed."""
+    if not TORCH_AVAILABLE:
+        raise ImportError(
+            "PyTorch is required for PyTorchAdapter. "
+            "Install it with: pip install torch"
+        )
+
+
+class PyTorchAdapter(BaseModelAdapter):
+    """
+    Adapter for PyTorch neural network models.
+
+    Wraps a PyTorch nn.Module to provide a consistent interface for
+    explainability methods. Handles device management, tensor/numpy
+    conversions, and supports both classification and regression tasks.
+
+    Attributes:
+        model: The PyTorch model (nn.Module)
+        task: "classification" or "regression"
+        device: torch.device for computation
+        class_names: List of class names (for classification)
+        feature_names: List of feature names
+        output_activation: Optional activation function for outputs
+
+    Example:
+        >>> model = MyNeuralNetwork()
+        >>> adapter = PyTorchAdapter(model, task="classification")
+        >>> probs = adapter.predict(X_numpy)  # Returns probabilities
+    """
+
+    def __init__(
+        self,
+        model,
+        task: str = "classification",
+        feature_names: Optional[List[str]] = None,
+        class_names: Optional[List[str]] = None,
+        device: Optional[str] = None,
+        output_activation: Optional[str] = "auto",
+        batch_size: int = 32
+    ):
+        """
+        Initialize the PyTorch adapter.
+
+        Args:
+            model: A PyTorch nn.Module model.
+            task: "classification" or "regression".
+            feature_names: List of input feature names.
+            class_names: List of output class names (classification only).
+            device: Device to run on ("cpu", "cuda", "cuda:0", etc.).
+                If None, auto-detects based on model parameters.
+            output_activation: Activation for output layer:
+                - "auto": softmax for classification, none for regression
+                - "softmax": Apply softmax (classification)
+                - "sigmoid": Apply sigmoid (binary classification)
+                - "none" or None: No activation (raw logits/values)
+            batch_size: Batch size for large inputs (default: 32).
+        """
+        _check_torch_available()
+
+        if not isinstance(model, nn.Module):
+            raise TypeError(
+                f"Expected nn.Module, got {type(model).__name__}. "
+                "For sklearn models, use SklearnAdapter instead."
+            )
+
+        super().__init__(model, feature_names)
+
+        self.task = task
+        self.class_names = list(class_names) if class_names else None
+        self.batch_size = batch_size
+
+        # Determine device
+        if device is not None:
+            self.device = torch.device(device)
+        else:
+            # Auto-detect from model parameters
+            try:
+                param = next(model.parameters())
+                self.device = param.device
+            except StopIteration:
+                # Model has no parameters, use CPU
+                self.device = torch.device("cpu")
+
+        # Move model to device and set to eval mode
+        self.model = model.to(self.device)
+        self.model.eval()
+
+        # Configure output activation
+        if output_activation == "auto":
+            if task == "classification":
+                self.output_activation = "softmax"
+            else:
+                self.output_activation = None
+        else:
+            self.output_activation = output_activation if output_activation != "none" else None
+
+    def _to_tensor(self, data: np.ndarray) -> "torch.Tensor":
+        """Convert numpy array to tensor on the correct device."""
+        if isinstance(data, torch.Tensor):
+            return data.to(self.device).float()
+        return torch.tensor(data, dtype=torch.float32, device=self.device)
+
+    def _to_numpy(self, tensor: "torch.Tensor") -> np.ndarray:
+        """Convert tensor to numpy array."""
+        return tensor.detach().cpu().numpy()
+
+    def _apply_activation(self, output: "torch.Tensor") -> "torch.Tensor":
+        """Apply output activation function."""
+        if self.output_activation == "softmax":
+            return torch.softmax(output, dim=-1)
+        elif self.output_activation == "sigmoid":
+            return torch.sigmoid(output)
+        return output
+
+    def predict(self, data: np.ndarray) -> np.ndarray:
+        """
+        Generate predictions for input data.
+
+        Args:
+            data: Input data as numpy array. Shape: (n_samples, n_features)
+                or (n_samples, channels, height, width) for images.
+
+        Returns:
+            Predictions as numpy array:
+            - Classification: probabilities of shape (n_samples, n_classes)
+            - Regression: values of shape (n_samples, n_outputs)
+        """
+        data = np.array(data)
+
+        # Handle single instance
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+
+        n_samples = data.shape[0]
+        outputs = []
+
+        with torch.no_grad():
+            for i in range(0, n_samples, self.batch_size):
+                batch = data[i:i + self.batch_size]
+                tensor_batch = self._to_tensor(batch)
+
+                output = self.model(tensor_batch)
+                output = self._apply_activation(output)
+                outputs.append(self._to_numpy(output))
+
+        return np.vstack(outputs)
+
+    def predict_with_gradients(
+        self,
+        data: np.ndarray,
+        target_class: Optional[int] = None
+    ) -> tuple:
+        """
+        Generate predictions and compute gradients w.r.t. inputs.
+
+        This is essential for gradient-based attribution methods like
+        Integrated Gradients, GradCAM, and Saliency Maps.
+
+        Args:
+            data: Input data as numpy array.
+            target_class: Class index for gradient computation.
+                If None, uses the predicted class.
+
+        Returns:
+            Tuple of (predictions, gradients) as numpy arrays.
+        """
+        data = np.array(data)
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+
+        # Convert to tensor with gradient tracking
+        tensor_data = self._to_tensor(data)
+        tensor_data.requires_grad_(True)
+
+        # Forward pass
+        output = self.model(tensor_data)
+        activated_output = self._apply_activation(output)
+
+        # Determine target for gradient
+        if self.task == "classification":
+            if target_class is None:
+                target_class = output.argmax(dim=-1)
+            elif isinstance(target_class, int):
+                target_class = torch.tensor([target_class] * data.shape[0], device=self.device)
+
+            # Select target class scores for gradient
+            target_scores = output.gather(1, target_class.view(-1, 1)).squeeze()
+        else:
+            # Regression: gradient w.r.t. output
+            target_scores = output.squeeze()
+
+        # Backward pass
+        if target_scores.dim() == 0:
+            target_scores.backward()
+        else:
+            target_scores.sum().backward()
+
+        gradients = tensor_data.grad
+
+        return (
+            self._to_numpy(activated_output),
+            self._to_numpy(gradients)
+        )
+
+    def get_layer_output(
+        self,
+        data: np.ndarray,
+        layer_name: str
+    ) -> np.ndarray:
+        """
+        Get intermediate layer activations.
+
+        Useful for methods like GradCAM that need feature map activations.
+
+        Args:
+            data: Input data as numpy array.
+            layer_name: Name of the layer to extract (as registered in model).
+
+        Returns:
+            Layer activations as numpy array.
+        """
+        data = np.array(data)
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+
+        activations = {}
+
+        def hook_fn(module, input, output):
+            activations['output'] = output
+
+        # Find and hook the layer
+        layer = dict(self.model.named_modules()).get(layer_name)
+        if layer is None:
+            available = list(dict(self.model.named_modules()).keys())
+            raise ValueError(
+                f"Layer '{layer_name}' not found. Available layers: {available}"
+            )
+
+        handle = layer.register_forward_hook(hook_fn)
+
+        try:
+            with torch.no_grad():
+                tensor_data = self._to_tensor(data)
+                _ = self.model(tensor_data)
+        finally:
+            handle.remove()
+
+        return self._to_numpy(activations['output'])
+
+    def get_layer_gradients(
+        self,
+        data: np.ndarray,
+        layer_name: str,
+        target_class: Optional[int] = None
+    ) -> tuple:
+        """
+        Get gradients of output w.r.t. a specific layer's activations.
+
+        Essential for GradCAM and similar visualization methods.
+
+        Args:
+            data: Input data as numpy array.
+            layer_name: Name of the layer for gradient computation.
+            target_class: Target class for gradient (classification).
+
+        Returns:
+            Tuple of (layer_activations, layer_gradients) as numpy arrays.
+        """
+        data = np.array(data)
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+
+        activations = {}
+        gradients = {}
+
+        def forward_hook(module, input, output):
+            activations['output'] = output
+
+        def backward_hook(module, grad_input, grad_output):
+            gradients['output'] = grad_output[0]
+
+        # Find and hook the layer
+        layer = dict(self.model.named_modules()).get(layer_name)
+        if layer is None:
+            available = list(dict(self.model.named_modules()).keys())
+            raise ValueError(
+                f"Layer '{layer_name}' not found. Available layers: {available}"
+            )
+
+        forward_handle = layer.register_forward_hook(forward_hook)
+        backward_handle = layer.register_full_backward_hook(backward_hook)
+
+        try:
+            tensor_data = self._to_tensor(data)
+            tensor_data.requires_grad_(True)
+
+            output = self.model(tensor_data)
+
+            if self.task == "classification":
+                if target_class is None:
+                    target_class = output.argmax(dim=-1)
+                elif isinstance(target_class, int):
+                    target_class = torch.tensor([target_class] * data.shape[0], device=self.device)
+
+                target_scores = output.gather(1, target_class.view(-1, 1)).squeeze()
+            else:
+                target_scores = output.squeeze()
+
+            if target_scores.dim() == 0:
+                target_scores.backward()
+            else:
+                target_scores.sum().backward()
+        finally:
+            forward_handle.remove()
+            backward_handle.remove()
+
+        return (
+            self._to_numpy(activations['output']),
+            self._to_numpy(gradients['output'])
+        )
+
+    def list_layers(self) -> List[str]:
+        """
+        List all named layers/modules in the model.
+
+        Returns:
+            List of layer names that can be used with get_layer_output/gradients.
+        """
+        return [name for name, _ in self.model.named_modules() if name]
+
+    def to(self, device: str) -> "PyTorchAdapter":
+        """
+        Move the model to a different device.
+
+        Args:
+            device: Target device ("cpu", "cuda", "cuda:0", etc.)
+
+        Returns:
+            Self for chaining.
+        """
+        self.device = torch.device(device)
+        self.model = self.model.to(self.device)
+        return self
+
+    def train_mode(self) -> "PyTorchAdapter":
+        """Set model to training mode (enables dropout, batchnorm updates)."""
+        self.model.train()
+        return self
+
+    def eval_mode(self) -> "PyTorchAdapter":
+        """Set model to evaluation mode (disables dropout, freezes batchnorm)."""
+        self.model.eval()
+        return self
explainiverse/core/registry.py
CHANGED
@@ -362,6 +362,7 @@ def _create_default_registry() -> ExplainerRegistry:
     """Create and populate the default global registry."""
     from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
     from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer
+    from explainiverse.explainers.attribution.treeshap_wrapper import TreeShapExplainer
     from explainiverse.explainers.rule_based.anchors_wrapper import AnchorsExplainer
     from explainiverse.explainers.global_explainers.permutation_importance import PermutationImportanceExplainer
     from explainiverse.explainers.global_explainers.partial_dependence import PartialDependenceExplainer
@@ -409,6 +410,23 @@ def _create_default_registry() -> ExplainerRegistry:
         )
     )
 
+    # Register TreeSHAP (optimized for tree models)
+    registry.register(
+        name="treeshap",
+        explainer_class=TreeShapExplainer,
+        meta=ExplainerMeta(
+            scope="local",
+            model_types=["tree", "ensemble"],
+            data_types=["tabular"],
+            task_types=["classification", "regression"],
+            description="TreeSHAP - exact SHAP values for tree-based models (RandomForest, XGBoost, etc.)",
+            paper_reference="Lundberg et al., 2018 - 'Consistent Individualized Feature Attribution for Tree Ensembles'",
+            complexity="O(TLD^2) - polynomial in tree depth",
+            requires_training_data=False,
+            supports_batching=True
+        )
+    )
+
     # Register Anchors
     registry.register(
         name="anchors",
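A minimal sketch (assuming `create()` forwards keyword arguments to the explainer constructor, as the package's quick start suggests) of using the newly registered name:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from explainiverse import default_registry

iris = load_iris()
model = RandomForestClassifier(n_estimators=50).fit(iris.data, iris.target)

explainer = default_registry.create(
    "treeshap",
    model=model,                      # raw tree model; no adapter needed
    feature_names=iris.feature_names,
    class_names=iris.target_names.tolist(),
)
print(explainer.explain(iris.data[0]).explanation_data["feature_attributions"])
```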
explainiverse/explainers/__init__.py
CHANGED

@@ -4,7 +4,8 @@ Explainiverse Explainers - comprehensive XAI method implementations.
 
 Local Explainers (instance-level):
 - LIME: Local Interpretable Model-agnostic Explanations
-- SHAP: SHapley Additive exPlanations
+- SHAP: SHapley Additive exPlanations (KernelSHAP - model-agnostic)
+- TreeSHAP: Optimized exact SHAP for tree-based models
 - Anchors: High-precision rule-based explanations
 - Counterfactual: Diverse counterfactual explanations
 
@@ -17,6 +18,7 @@ Global Explainers (model-level):
 
 from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
 from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer
+from explainiverse.explainers.attribution.treeshap_wrapper import TreeShapExplainer
 from explainiverse.explainers.rule_based.anchors_wrapper import AnchorsExplainer
 from explainiverse.explainers.counterfactual.dice_wrapper import CounterfactualExplainer
 from explainiverse.explainers.global_explainers.permutation_importance import PermutationImportanceExplainer
@@ -28,6 +30,7 @@ __all__ = [
     # Local explainers
     "LimeExplainer",
     "ShapExplainer",
+    "TreeShapExplainer",
     "AnchorsExplainer",
     "CounterfactualExplainer",
     # Global explainers
explainiverse/explainers/attribution/__init__.py
CHANGED

@@ -5,5 +5,6 @@ Attribution-based explainers - feature importance explanations.
 
 from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
 from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer
+from explainiverse.explainers.attribution.treeshap_wrapper import TreeShapExplainer
 
-__all__ = ["LimeExplainer", "ShapExplainer"]
+__all__ = ["LimeExplainer", "ShapExplainer", "TreeShapExplainer"]
explainiverse/explainers/attribution/treeshap_wrapper.py
ADDED

@@ -0,0 +1,434 @@
+# src/explainiverse/explainers/attribution/treeshap_wrapper.py
+"""
+TreeSHAP Explainer - Optimized SHAP for Tree-based Models.
+
+TreeSHAP computes exact SHAP values in polynomial time for tree-based models,
+making it significantly faster than KernelSHAP while providing exact (not
+approximate) Shapley values.
+
+Reference:
+    Lundberg, S.M., Erion, G.G., & Lee, S.I. (2018). Consistent Individualized
+    Feature Attribution for Tree Ensembles. arXiv:1802.03888.
+
+Supported Models:
+    - scikit-learn: RandomForest, GradientBoosting, DecisionTree, ExtraTrees
+    - XGBoost: XGBClassifier, XGBRegressor
+    - LightGBM: LGBMClassifier, LGBMRegressor (if installed)
+    - CatBoost: CatBoostClassifier, CatBoostRegressor (if installed)
+"""
+
+import numpy as np
+import shap
+from typing import List, Optional, Union
+
+from explainiverse.core.explainer import BaseExplainer
+from explainiverse.core.explanation import Explanation
+
+
+# Tree-based model types that TreeSHAP supports
+SUPPORTED_TREE_MODELS = (
+    "RandomForestClassifier",
+    "RandomForestRegressor",
+    "GradientBoostingClassifier",
+    "GradientBoostingRegressor",
+    "DecisionTreeClassifier",
+    "DecisionTreeRegressor",
+    "ExtraTreesClassifier",
+    "ExtraTreesRegressor",
+    "XGBClassifier",
+    "XGBRegressor",
+    "XGBRFClassifier",
+    "XGBRFRegressor",
+    "LGBMClassifier",
+    "LGBMRegressor",
+    "CatBoostClassifier",
+    "CatBoostRegressor",
+    "HistGradientBoostingClassifier",
+    "HistGradientBoostingRegressor",
+)
+
+
+def _is_tree_model(model) -> bool:
+    """Check if a model is a supported tree-based model."""
+    model_name = type(model).__name__
+    return model_name in SUPPORTED_TREE_MODELS
+
+
+def _get_raw_model(model):
+    """
+    Extract the raw model from an adapter if necessary.
+
+    TreeExplainer needs the actual sklearn/xgboost model, not an adapter.
+    """
+    # If it's an adapter, get the underlying model
+    if hasattr(model, 'model'):
+        return model.model
+    return model
+
+
+class TreeShapExplainer(BaseExplainer):
+    """
+    TreeSHAP explainer for tree-based models.
+
+    Uses SHAP's TreeExplainer to compute exact SHAP values in polynomial time.
+    This is significantly faster than KernelSHAP for supported tree models
+    and provides exact Shapley values rather than approximations.
+
+    Key advantages over KernelSHAP:
+    - Exact SHAP values (not approximations)
+    - O(TLD²) complexity vs O(TL2^M) for KernelSHAP
+    - Can compute interaction values
+    - No background data sampling needed
+
+    Attributes:
+        model: The tree-based model (sklearn, XGBoost, LightGBM, or CatBoost)
+        feature_names: List of feature names
+        class_names: List of class names for classification
+        explainer: The underlying SHAP TreeExplainer
+        task: "classification" or "regression"
+    """
+
+    def __init__(
+        self,
+        model,
+        feature_names: List[str],
+        class_names: Optional[List[str]] = None,
+        background_data: Optional[np.ndarray] = None,
+        task: str = "classification",
+        model_output: str = "auto",
+        feature_perturbation: str = "tree_path_dependent"
+    ):
+        """
+        Initialize the TreeSHAP explainer.
+
+        Args:
+            model: A tree-based model or adapter containing one.
+                Supported: RandomForest, GradientBoosting, XGBoost,
+                LightGBM, CatBoost, DecisionTree, ExtraTrees.
+            feature_names: List of feature names.
+            class_names: List of class names (for classification).
+            background_data: Optional background dataset for interventional
+                feature perturbation. If None, uses tree_path_dependent.
+            task: "classification" or "regression".
+            model_output: How to transform model output. Options:
+                - "auto": Automatically detect
+                - "raw": Raw model output
+                - "probability": Probability output (classification)
+                - "log_loss": Log loss output
+            feature_perturbation: Method for handling feature perturbation:
+                - "tree_path_dependent": Fast, uses tree structure
+                - "interventional": Slower, requires background data
+        """
+        # Extract raw model if wrapped in adapter
+        raw_model = _get_raw_model(model)
+
+        # Validate that it's a supported tree model
+        if not _is_tree_model(raw_model):
+            model_type = type(raw_model).__name__
+            raise ValueError(
+                f"TreeSHAP requires a tree-based model. Got {model_type}. "
+                f"Supported models: {', '.join(SUPPORTED_TREE_MODELS[:6])}..."
+            )
+
+        super().__init__(model)
+        self.raw_model = raw_model
+        self.feature_names = list(feature_names)
+        self.class_names = list(class_names) if class_names else None
+        self.task = task
+        self.model_output = model_output
+        self.feature_perturbation = feature_perturbation
+
+        # Create TreeExplainer
+        explainer_kwargs = {}
+
+        if feature_perturbation == "interventional" and background_data is not None:
+            explainer_kwargs["data"] = background_data
+            explainer_kwargs["feature_perturbation"] = "interventional"
+
+        if model_output != "auto":
+            explainer_kwargs["model_output"] = model_output
+
+        self.explainer = shap.TreeExplainer(raw_model, **explainer_kwargs)
+        self.background_data = background_data
+
+    def explain(
+        self,
+        instance: np.ndarray,
+        target_class: Optional[int] = None,
+        check_additivity: bool = False
+    ) -> Explanation:
+        """
+        Generate TreeSHAP explanation for a single instance.
+
+        Args:
+            instance: 1D numpy array of input features.
+            target_class: For multi-class, which class to explain.
+                If None, uses the predicted class.
+            check_additivity: Whether to verify SHAP values sum to
+                prediction - expected_value.
+
+        Returns:
+            Explanation object with feature attributions.
+        """
+        instance = np.array(instance).flatten()
+        instance_2d = instance.reshape(1, -1)
+
+        # Compute SHAP values
+        shap_values = self.explainer.shap_values(
+            instance_2d,
+            check_additivity=check_additivity
+        )
+
+        # Handle different output formats
+        if isinstance(shap_values, list):
+            # Multi-class classification: list of arrays, one per class
+            n_classes = len(shap_values)
+
+            if target_class is None:
+                # Use predicted class
+                if hasattr(self.raw_model, 'predict'):
+                    pred = self.raw_model.predict(instance_2d)[0]
+                    target_class = int(pred)
+                else:
+                    target_class = 0
+
+            # Ensure target_class is valid
+            target_class = min(target_class, n_classes - 1)
+            class_shap = shap_values[target_class][0]
+
+            # Get class name
+            if self.class_names and target_class < len(self.class_names):
+                label_name = self.class_names[target_class]
+            else:
+                label_name = f"class_{target_class}"
+
+            # Store all class SHAP values for reference
+            all_class_shap = {
+                (self.class_names[i] if self.class_names and i < len(self.class_names)
+                 else f"class_{i}"): shap_values[i][0].tolist()
+                for i in range(n_classes)
+            }
+        else:
+            # Binary classification or regression
+            class_shap = shap_values[0] if shap_values.ndim > 1 else shap_values.flatten()
+            label_name = self.class_names[1] if self.class_names and len(self.class_names) > 1 else "output"
+            all_class_shap = None
+
+        # Build attributions dict
+        flat_shap = np.array(class_shap).flatten()
+        attributions = {
+            fname: float(flat_shap[i])
+            for i, fname in enumerate(self.feature_names)
+        }
+
+        # Get expected value (base value)
+        expected_value = self.explainer.expected_value
+        if isinstance(expected_value, (list, np.ndarray)):
+            if target_class is not None and target_class < len(expected_value):
+                base_value = float(expected_value[target_class])
+            else:
+                base_value = float(expected_value[0])
+        else:
+            base_value = float(expected_value)
+
+        explanation_data = {
+            "feature_attributions": attributions,
+            "base_value": base_value,
+            "shap_values_raw": flat_shap.tolist(),
+        }
+
+        if all_class_shap is not None:
+            explanation_data["all_class_shap_values"] = all_class_shap
+
+        return Explanation(
+            explainer_name="TreeSHAP",
+            target_class=label_name,
+            explanation_data=explanation_data
+        )
+
+    def explain_batch(
+        self,
+        X: np.ndarray,
+        target_class: Optional[int] = None,
+        check_additivity: bool = False
+    ) -> List[Explanation]:
+        """
+        Generate TreeSHAP explanations for multiple instances efficiently.
+
+        TreeSHAP can process batches more efficiently than individual calls.
+
+        Args:
+            X: 2D numpy array of instances (n_samples, n_features).
+            target_class: For multi-class, which class to explain.
+            check_additivity: Whether to verify SHAP value additivity.
+
+        Returns:
+            List of Explanation objects.
+        """
+        X = np.array(X)
+        if X.ndim == 1:
+            X = X.reshape(1, -1)
+
+        # Compute SHAP values for all instances at once
+        shap_values = self.explainer.shap_values(X, check_additivity=check_additivity)
+
+        explanations = []
+        for i in range(X.shape[0]):
+            if isinstance(shap_values, list):
+                # Multi-class
+                n_classes = len(shap_values)
+                tc = target_class if target_class is not None else 0
+                tc = min(tc, n_classes - 1)
+                class_shap = shap_values[tc][i]
+
+                if self.class_names and tc < len(self.class_names):
+                    label_name = self.class_names[tc]
+                else:
+                    label_name = f"class_{tc}"
+            else:
+                class_shap = shap_values[i]
+                label_name = self.class_names[1] if self.class_names and len(self.class_names) > 1 else "output"
+
+            flat_shap = np.array(class_shap).flatten()
+            attributions = {
+                fname: float(flat_shap[j])
+                for j, fname in enumerate(self.feature_names)
+            }
+
+            expected_value = self.explainer.expected_value
+            if isinstance(expected_value, (list, np.ndarray)):
+                tc = target_class if target_class is not None else 0
+                base_value = float(expected_value[min(tc, len(expected_value) - 1)])
+            else:
+                base_value = float(expected_value)
+
+            explanations.append(Explanation(
+                explainer_name="TreeSHAP",
+                target_class=label_name,
+                explanation_data={
+                    "feature_attributions": attributions,
+                    "base_value": base_value,
+                    "shap_values_raw": flat_shap.tolist(),
+                }
+            ))
+
+        return explanations
+
+    def explain_interactions(
+        self,
+        instance: np.ndarray,
+        target_class: Optional[int] = None
+    ) -> Explanation:
+        """
+        Compute SHAP interaction values for an instance.
+
+        Interaction values show how pairs of features jointly contribute
+        to the prediction. The diagonal contains main effects.
+
+        Args:
+            instance: 1D numpy array of input features.
+            target_class: For multi-class, which class to explain.
+
+        Returns:
+            Explanation object with interaction matrix.
+        """
+        instance = np.array(instance).flatten()
+        instance_2d = instance.reshape(1, -1)
+
+        # Compute interaction values
+        interaction_values = self.explainer.shap_interaction_values(instance_2d)
+
+        # Determine target class for prediction
+        if target_class is None and hasattr(self.raw_model, 'predict'):
+            target_class = int(self.raw_model.predict(instance_2d)[0])
+        elif target_class is None:
+            target_class = 0
+
+        # Handle different return formats from shap_interaction_values
+        if isinstance(interaction_values, list):
+            # Multi-class: list of arrays, one per class
+            n_classes = len(interaction_values)
+            tc = min(target_class, n_classes - 1)
+            interactions = np.array(interaction_values[tc][0])
+
+            if self.class_names and tc < len(self.class_names):
+                label_name = self.class_names[tc]
+            else:
+                label_name = f"class_{tc}"
+        elif interaction_values.ndim == 4:
+            # Shape: (n_samples, n_features, n_features, n_classes)
+            n_classes = interaction_values.shape[3]
+            tc = min(target_class, n_classes - 1)
+            interactions = interaction_values[0, :, :, tc]
+
+            if self.class_names and tc < len(self.class_names):
+                label_name = self.class_names[tc]
+            else:
+                label_name = f"class_{tc}"
+        else:
+            # Binary or regression: (n_samples, n_features, n_features)
+            interactions = interaction_values[0]
+            label_name = self.class_names[1] if self.class_names and len(self.class_names) > 1 else "output"
+
+        # Ensure interactions is 2D (n_features x n_features)
+        interactions = np.array(interactions)
+        if interactions.ndim > 2:
+            # If still multi-dimensional, take first slice
+            interactions = interactions[:, :, 0] if interactions.ndim == 3 else interactions
+
+        # Build interaction dict with feature name pairs
+        n_features = len(self.feature_names)
+        interaction_dict = {}
+        main_effects = {}
+
+        for i in range(n_features):
+            fname_i = self.feature_names[i]
+            val = interactions[i, i]
+            main_effects[fname_i] = float(val) if np.isscalar(val) or val.size == 1 else float(val.flat[0])
+
+            for j in range(i + 1, n_features):
+                fname_j = self.feature_names[j]
+                # Interaction values are symmetric, so we sum both directions
+                val_ij = interactions[i, j]
+                val_ji = interactions[j, i]
+                ij = float(val_ij) if np.isscalar(val_ij) or val_ij.size == 1 else float(val_ij.flat[0])
+                ji = float(val_ji) if np.isscalar(val_ji) or val_ji.size == 1 else float(val_ji.flat[0])
+                interaction_dict[f"{fname_i} x {fname_j}"] = ij + ji
+
+        # Sort interactions by absolute value
+        sorted_interactions = dict(sorted(
+            interaction_dict.items(),
+            key=lambda x: abs(x[1]),
+            reverse=True
+        ))
+
+        return Explanation(
+            explainer_name="TreeSHAP_Interactions",
+            target_class=label_name,
+            explanation_data={
+                "feature_attributions": main_effects,
+                "interactions": sorted_interactions,
+                "interaction_matrix": interactions.tolist(),
+                "feature_names": self.feature_names
+            }
+        )
+
+    def get_expected_value(self, target_class: Optional[int] = None) -> float:
+        """
+        Get the expected (base) value of the model.
+
+        This is the average model output over the background dataset.
+
+        Args:
+            target_class: For multi-class, which class's expected value.
+
+        Returns:
+            The expected value as a float.
+        """
+        expected_value = self.explainer.expected_value
+
+        if isinstance(expected_value, (list, np.ndarray)):
+            tc = target_class if target_class is not None else 0
+            return float(expected_value[min(tc, len(expected_value) - 1)])
+
+        return float(expected_value)
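A minimal sketch (assumption: TreeSHAP's local-accuracy property, i.e. the base value plus the attributions reconstructs the model's raw output for the explained class) that sanity-checks the wrapper's output:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from explainiverse.explainers import TreeShapExplainer

data = load_breast_cancer()
model = GradientBoostingClassifier().fit(data.data, data.target)

explainer = TreeShapExplainer(
    model,
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),
)
exp = explainer.explain(data.data[0], check_additivity=True)

# base_value + sum(attributions) reconstructs the model's margin (log-odds
# for GradientBoostingClassifier under tree_path_dependent perturbation)
total = exp.explanation_data["base_value"] + sum(
    exp.explanation_data["feature_attributions"].values()
)
print(round(total, 6))
```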
{explainiverse-0.2.0.dist-info → explainiverse-0.2.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: explainiverse
-Version: 0.2.0
+Version: 0.2.2
 Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
 Home-page: https://github.com/jemsbhai/explainiverse
 License: MIT
@@ -17,11 +17,13 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Provides-Extra: torch
 Requires-Dist: lime (>=0.2.0.1,<0.3.0.0)
 Requires-Dist: numpy (>=1.24,<2.0)
 Requires-Dist: scikit-learn (>=1.1,<1.6)
 Requires-Dist: scipy (>=1.10,<2.0)
 Requires-Dist: shap (>=0.48.0,<0.49.0)
+Requires-Dist: torch (>=2.0) ; extra == "torch"
 Requires-Dist: xgboost (>=1.7,<3.0)
 Project-URL: Repository, https://github.com/jemsbhai/explainiverse
 Description-Content-Type: text/markdown
@@ -29,7 +31,7 @@ Description-Content-Type: text/markdown
 # Explainiverse
 
 **Explainiverse** is a unified, extensible Python framework for Explainable AI (XAI).
-It provides a standardized interface for model-agnostic explainability with 8 state-of-the-art XAI methods, evaluation metrics, and a plugin registry for easy extensibility.
+It provides a standardized interface for model-agnostic explainability with 9 state-of-the-art XAI methods, evaluation metrics, and a plugin registry for easy extensibility.
 
 ---
 
@@ -40,6 +42,7 @@ It provides a standardized interface for model-agnostic explainability with 8 st
 **Local Explainers** (instance-level explanations):
 - **LIME** - Local Interpretable Model-agnostic Explanations ([Ribeiro et al., 2016](https://arxiv.org/abs/1602.04938))
 - **SHAP** - SHapley Additive exPlanations via KernelSHAP ([Lundberg & Lee, 2017](https://arxiv.org/abs/1705.07874))
+- **TreeSHAP** - Exact SHAP values for tree models, 10x+ faster ([Lundberg et al., 2018](https://arxiv.org/abs/1802.03888))
 - **Anchors** - High-precision rule-based explanations ([Ribeiro et al., 2018](https://ojs.aaai.org/index.php/AAAI/article/view/11491))
 - **Counterfactual** - DiCE-style diverse counterfactual explanations ([Mothilal et al., 2020](https://arxiv.org/abs/1905.07697))
 
@@ -62,7 +65,7 @@ It provides a standardized interface for model-agnostic explainability with 8 st
 ### 🧪 Standardized Interface
 - Consistent `BaseExplainer` API
 - Unified `Explanation` output format
-- Model adapters for sklearn and
+- Model adapters for sklearn and PyTorch
 
 ---
 
@@ -74,6 +77,12 @@ From PyPI:
 pip install explainiverse
 ```
 
+With PyTorch support (for neural network explanations):
+
+```bash
+pip install explainiverse[torch]
+```
+
 For development:
 
 ```bash
@@ -100,7 +109,7 @@ adapter = SklearnAdapter(model, class_names=iris.target_names.tolist())
 
 # List available explainers
 print(default_registry.list_explainers())
-# ['lime', 'shap', 'anchors', 'counterfactual', 'permutation_importance', 'partial_dependence', 'ale', 'sage']
+# ['lime', 'shap', 'treeshap', 'anchors', 'counterfactual', 'permutation_importance', 'partial_dependence', 'ale', 'sage']
 
 # Create and use an explainer
 explainer = default_registry.create(
@@ -119,11 +128,11 @@ print(explanation.explanation_data["feature_attributions"])
 ```python
 # Find local explainers for tabular data
 local_tabular = default_registry.filter(scope="local", data_type="tabular")
-print(local_tabular)  # ['lime', 'shap', 'anchors', 'counterfactual']
+print(local_tabular)  # ['lime', 'shap', 'treeshap', 'anchors', 'counterfactual']
 
-# Find
-
-print(
+# Find explainers optimized for tree models
+tree_explainers = default_registry.filter(model_type="tree")
+print(tree_explainers)  # ['treeshap']
 
 # Get recommendations
 recommendations = default_registry.recommend(
@@ -133,6 +142,64 @@ recommendations = default_registry.recommend(
 )
 ```
 
+### TreeSHAP for Tree Models (10x+ Faster)
+
+```python
+from explainiverse.explainers import TreeShapExplainer
+from sklearn.ensemble import RandomForestClassifier
+
+# Train a tree-based model
+model = RandomForestClassifier(n_estimators=100).fit(X_train, y_train)
+
+# TreeSHAP works directly with the model (no adapter needed)
+explainer = TreeShapExplainer(
+    model=model,
+    feature_names=feature_names,
+    class_names=class_names
+)
+
+# Single instance explanation
+explanation = explainer.explain(X_test[0])
+print(explanation.explanation_data["feature_attributions"])
+
+# Batch explanations (efficient)
+explanations = explainer.explain_batch(X_test[:10])
+
+# Feature interactions
+interactions = explainer.explain_interactions(X_test[0])
+print(interactions.explanation_data["interaction_matrix"])
+```
+
+### PyTorch Adapter for Neural Networks
+
+```python
+from explainiverse import PyTorchAdapter
+import torch.nn as nn
+
+# Define a PyTorch model
+model = nn.Sequential(
+    nn.Linear(10, 64),
+    nn.ReLU(),
+    nn.Linear(64, 3)
+)
+
+# Wrap with adapter
+adapter = PyTorchAdapter(
+    model,
+    task="classification",
+    class_names=["cat", "dog", "bird"]
+)
+
+# Use with any explainer
+predictions = adapter.predict(X)  # Returns numpy array
+
+# Get gradients for attribution methods
+predictions, gradients = adapter.predict_with_gradients(X)
+
+# Access intermediate layers
+activations = adapter.get_layer_output(X, layer_name="0")
+```
+
 ### Using Specific Explainers
 
 ```python
@@ -233,12 +300,14 @@ poetry run pytest tests/test_new_explainers.py -v
 ## Roadmap
 
 - [x] LIME, SHAP (KernelSHAP)
+- [x] TreeSHAP (optimized for tree models) ✅ NEW
 - [x] Anchors, Counterfactuals
 - [x] Permutation Importance, PDP, ALE, SAGE
 - [x] Explainer Registry with filtering
-- [
+- [x] PyTorch Adapter ✅ NEW
 - [ ] Integrated Gradients (gradient-based for neural nets)
-- [ ]
+- [ ] GradCAM for CNNs
+- [ ] TensorFlow adapter
 - [ ] Interactive visualization dashboard
 
 ---
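Following on from the README's TreeSHAP example, a minimal sketch (assumption: the `explanation_data` dict layouts defined in treeshap_wrapper.py above, where `"interactions"` is sorted by absolute value) that pulls out the strongest pairwise interaction:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from explainiverse.explainers import TreeShapExplainer

iris = load_iris()
rf = RandomForestClassifier(n_estimators=50).fit(iris.data, iris.target)
ex = TreeShapExplainer(rf, feature_names=iris.feature_names,
                       class_names=iris.target_names.tolist())

result = ex.explain_interactions(iris.data[0])
pairs = result.explanation_data["interactions"]  # sorted by |value|, descending
name, value = next(iter(pairs.items()))         # first entry = strongest pair
print(f"strongest interaction: {name} -> {value:+.4f}")
```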
{explainiverse-0.2.0.dist-info → explainiverse-0.2.2.dist-info}/RECORD
CHANGED

@@ -1,19 +1,21 @@
-explainiverse/__init__.py,sha256
-explainiverse/adapters/__init__.py,sha256=
+explainiverse/__init__.py,sha256=-4H6WbfGwpeoNpO9w0CEahKQBPsvIYe_lK5e10cZWD0,1612
+explainiverse/adapters/__init__.py,sha256=HcQGISyp-YQ4jEj2IYveX_c9X5otLcTNWRnVRRhzRik,781
 explainiverse/adapters/base_adapter.py,sha256=Nqt0GeDn_-PjTyJcZsE8dRTulavqFQsv8sMYWS_ps-M,603
+explainiverse/adapters/pytorch_adapter.py,sha256=GTilJAR1VF_OgWG88qZoqlqefHaSXB3i9iOwCJkyHTg,13318
 explainiverse/adapters/sklearn_adapter.py,sha256=pzIBtMuqrG-6ZbUqUCMt7rSk3Ow0FgrY268FSweFvw4,958
 explainiverse/core/__init__.py,sha256=P3jHMnH5coFqTTO1w-gT-rurkCM1-9r3pF-055pbXMg,474
 explainiverse/core/explainer.py,sha256=Z9on-9VblYDlQx9oBm1BHpmAf_NsQajZ3qr-u48Aejo,784
 explainiverse/core/explanation.py,sha256=6zxFh_TH8tFHc-r_H5-WHQ05Sp1Kp2TxLz3gyFek5jo,881
-explainiverse/core/registry.py,sha256=
+explainiverse/core/registry.py,sha256=_BXWi1fJY3cGjYA1Xn1DwvY91jbpJrpX6_8EVzrRT20,19876
 explainiverse/engine/__init__.py,sha256=1sZO8nH1mmwK2e-KUavBQm7zYDWUe27nyWoFy9tgsiA,197
 explainiverse/engine/suite.py,sha256=sq8SK_6Pf0qRckTmVJ7Mdosu9bhkjAGPGN8ymLGFP9E,4914
 explainiverse/evaluation/__init__.py,sha256=Y50L_b4HKthg4epwcayPHXh0l4i4MUuzvaNlqPmUNZY,212
 explainiverse/evaluation/metrics.py,sha256=tSBXtyA_-0zOGCGjlPZU6LdGKRH_QpWfgKa78sdlovs,7453
-explainiverse/explainers/__init__.py,sha256=
-explainiverse/explainers/attribution/__init__.py,sha256=
+explainiverse/explainers/__init__.py,sha256=Op-Z_BTJ7BdqA_9gTnruomN2-rKtrkPCt1Zq1iCzxr0,1758
+explainiverse/explainers/attribution/__init__.py,sha256=YeVs9bS_IWDtqGbp6T37V6Zp5ZDWzLdAXHxxyFGpiQM,431
 explainiverse/explainers/attribution/lime_wrapper.py,sha256=OnXIV7t6yd-vt38sIi7XmHFbgzlZfCEbRlFyGGd5XiE,3245
 explainiverse/explainers/attribution/shap_wrapper.py,sha256=tKie5AvN7mb55PWOYdMvW0lUAYjfHPzYosEloEY2ZzI,3210
+explainiverse/explainers/attribution/treeshap_wrapper.py,sha256=LcBjHzQjmeyWCwLXALJ0WFQ9ol-N_8dod577EDxFDKY,16758
 explainiverse/explainers/counterfactual/__init__.py,sha256=gEV6P8h2fZ3-pv5rqp5sNDqrLErh5ntqpxIIBVCMFv4,247
 explainiverse/explainers/counterfactual/dice_wrapper.py,sha256=PyJYF-z1nyyy0mFROnkJqPtcuT2PwEBARwfh37mZ5ew,11373
 explainiverse/explainers/global_explainers/__init__.py,sha256=91xayho0r-fVeIxBLTxF-aBaBhRTRRXxGZ7oUHh7z64,713
@@ -23,7 +25,7 @@ explainiverse/explainers/global_explainers/permutation_importance.py,sha256=bcgK
 explainiverse/explainers/global_explainers/sage.py,sha256=57Xw1SK529x5JXWt0TVrcFYUUP3C65LfUwgoM-Z3gaw,5839
 explainiverse/explainers/rule_based/__init__.py,sha256=gKzlFCAzwurAMLJcuYgal4XhDj1thteBGcaHWmN7iWk,243
 explainiverse/explainers/rule_based/anchors_wrapper.py,sha256=ML7W6aam-eMGZHy5ilol8qupZvNBJpYAFatEEPnuMyo,13254
-explainiverse-0.2.0.dist-info/LICENSE,sha256=28rbHe8rJgmUlRdxJACfq1Sj-MtCEhyHxkJedQd1ZYA,1070
-explainiverse-0.2.0.dist-info/METADATA,sha256=
-explainiverse-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-explainiverse-0.2.0.dist-info/RECORD,,
+explainiverse-0.2.2.dist-info/LICENSE,sha256=28rbHe8rJgmUlRdxJACfq1Sj-MtCEhyHxkJedQd1ZYA,1070
+explainiverse-0.2.2.dist-info/METADATA,sha256=kis3ejJCLRhBJWf5p13FzY2ZeSbnWfJxk6LS1hd7A1w,9497
+explainiverse-0.2.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+explainiverse-0.2.2.dist-info/RECORD,,

{explainiverse-0.2.0.dist-info → explainiverse-0.2.2.dist-info}/LICENSE
File without changes

{explainiverse-0.2.0.dist-info → explainiverse-0.2.2.dist-info}/WHEEL
File without changes