explainiverse 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/__init__.py +5 -4
- explainiverse/adapters/pytorch_adapter.py +88 -25
- explainiverse/core/explanation.py +165 -10
- explainiverse/core/registry.py +18 -0
- explainiverse/engine/suite.py +187 -78
- explainiverse/evaluation/metrics.py +189 -108
- explainiverse/explainers/attribution/lime_wrapper.py +90 -7
- explainiverse/explainers/attribution/shap_wrapper.py +104 -8
- explainiverse/explainers/gradient/__init__.py +3 -0
- explainiverse/explainers/gradient/integrated_gradients.py +189 -76
- explainiverse/explainers/gradient/lrp.py +1206 -0
- {explainiverse-0.7.0.dist-info → explainiverse-0.8.0.dist-info}/METADATA +76 -13
- {explainiverse-0.7.0.dist-info → explainiverse-0.8.0.dist-info}/RECORD +15 -14
- {explainiverse-0.7.0.dist-info → explainiverse-0.8.0.dist-info}/LICENSE +0 -0
- {explainiverse-0.7.0.dist-info → explainiverse-0.8.0.dist-info}/WHEEL +0 -0
|
@@ -8,14 +8,36 @@ game-theoretic Shapley values, offering both local and global interpretability.
|
|
|
8
8
|
Reference:
|
|
9
9
|
Lundberg, S.M. & Lee, S.I. (2017). A Unified Approach to Interpreting
|
|
10
10
|
Model Predictions. NeurIPS 2017.
|
|
11
|
+
https://arxiv.org/abs/1705.07874
|
|
11
12
|
"""
|
|
12
13
|
|
|
13
|
-
import shap
|
|
14
14
|
import numpy as np
|
|
15
|
+
from typing import List, Optional
|
|
15
16
|
|
|
16
17
|
from explainiverse.core.explainer import BaseExplainer
|
|
17
18
|
from explainiverse.core.explanation import Explanation
|
|
18
19
|
|
|
20
|
+
# Lazy import check - don't import shap at module level
|
|
21
|
+
_SHAP_AVAILABLE = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _check_shap_available():
|
|
25
|
+
"""Check if SHAP is available and raise ImportError if not."""
|
|
26
|
+
global _SHAP_AVAILABLE
|
|
27
|
+
|
|
28
|
+
if _SHAP_AVAILABLE is None:
|
|
29
|
+
try:
|
|
30
|
+
import shap
|
|
31
|
+
_SHAP_AVAILABLE = True
|
|
32
|
+
except ImportError:
|
|
33
|
+
_SHAP_AVAILABLE = False
|
|
34
|
+
|
|
35
|
+
if not _SHAP_AVAILABLE:
|
|
36
|
+
raise ImportError(
|
|
37
|
+
"SHAP is required for ShapExplainer. "
|
|
38
|
+
"Install it with: pip install shap"
|
|
39
|
+
)
|
|
40
|
+
|
|
19
41
|
|
|
20
42
|
class ShapExplainer(BaseExplainer):
|
|
21
43
|
"""
|
|
@@ -30,9 +52,25 @@ class ShapExplainer(BaseExplainer):
|
|
|
30
52
|
feature_names: List of feature names
|
|
31
53
|
class_names: List of class labels
|
|
32
54
|
explainer: The underlying SHAP KernelExplainer
|
|
55
|
+
|
|
56
|
+
Example:
|
|
57
|
+
>>> from explainiverse.explainers.attribution import ShapExplainer
|
|
58
|
+
>>> explainer = ShapExplainer(
|
|
59
|
+
... model=adapter,
|
|
60
|
+
... background_data=X_train[:100],
|
|
61
|
+
... feature_names=feature_names,
|
|
62
|
+
... class_names=class_names
|
|
63
|
+
... )
|
|
64
|
+
>>> explanation = explainer.explain(X_test[0])
|
|
33
65
|
"""
|
|
34
66
|
|
|
35
|
-
def __init__(
|
|
67
|
+
def __init__(
|
|
68
|
+
self,
|
|
69
|
+
model,
|
|
70
|
+
background_data: np.ndarray,
|
|
71
|
+
feature_names: List[str],
|
|
72
|
+
class_names: List[str]
|
|
73
|
+
):
|
|
36
74
|
"""
|
|
37
75
|
Initialize the SHAP explainer.
|
|
38
76
|
|
|
@@ -42,13 +80,31 @@ class ShapExplainer(BaseExplainer):
|
|
|
42
80
|
Typically a representative sample of training data.
|
|
43
81
|
feature_names: List of feature names.
|
|
44
82
|
class_names: List of class labels.
|
|
83
|
+
|
|
84
|
+
Raises:
|
|
85
|
+
ImportError: If shap package is not installed.
|
|
45
86
|
"""
|
|
87
|
+
# Check availability before importing
|
|
88
|
+
_check_shap_available()
|
|
89
|
+
|
|
90
|
+
# Import after check passes
|
|
91
|
+
import shap as shap_module
|
|
92
|
+
|
|
46
93
|
super().__init__(model)
|
|
47
94
|
self.feature_names = list(feature_names)
|
|
48
95
|
self.class_names = list(class_names)
|
|
49
|
-
self.
|
|
96
|
+
self.background_data = np.asarray(background_data)
|
|
97
|
+
|
|
98
|
+
self.explainer = shap_module.KernelExplainer(
|
|
99
|
+
model.predict,
|
|
100
|
+
self.background_data
|
|
101
|
+
)
|
|
50
102
|
|
|
51
|
-
def explain(
|
|
103
|
+
def explain(
|
|
104
|
+
self,
|
|
105
|
+
instance: np.ndarray,
|
|
106
|
+
top_labels: int = 1
|
|
107
|
+
) -> Explanation:
|
|
52
108
|
"""
|
|
53
109
|
Generate SHAP explanation for a single instance.
|
|
54
110
|
|
|
@@ -59,21 +115,26 @@ class ShapExplainer(BaseExplainer):
|
|
|
59
115
|
Returns:
|
|
60
116
|
Explanation object with feature attributions
|
|
61
117
|
"""
|
|
62
|
-
instance = np.
|
|
118
|
+
instance = np.asarray(instance)
|
|
119
|
+
original_instance = instance.copy()
|
|
120
|
+
|
|
121
|
+
if instance.ndim == 1:
|
|
122
|
+
instance = instance.reshape(1, -1)
|
|
123
|
+
|
|
63
124
|
shap_values = self.explainer.shap_values(instance)
|
|
64
125
|
|
|
65
126
|
if isinstance(shap_values, list):
|
|
66
127
|
# Multi-class: list of arrays, one per class
|
|
67
128
|
predicted_probs = self.model.predict(instance)[0]
|
|
68
129
|
top_indices = np.argsort(predicted_probs)[-top_labels:][::-1]
|
|
69
|
-
label_index = top_indices[0]
|
|
130
|
+
label_index = int(top_indices[0])
|
|
70
131
|
label_name = self.class_names[label_index]
|
|
71
132
|
class_shap = shap_values[label_index][0]
|
|
72
133
|
else:
|
|
73
134
|
# Single-class (regression or binary classification)
|
|
74
135
|
label_index = 0
|
|
75
136
|
label_name = self.class_names[0] if self.class_names else "class_0"
|
|
76
|
-
class_shap = shap_values[0]
|
|
137
|
+
class_shap = shap_values[0] if shap_values.ndim > 1 else shap_values
|
|
77
138
|
|
|
78
139
|
# Build attributions dict
|
|
79
140
|
flat_shap = np.array(class_shap).flatten()
|
|
@@ -81,9 +142,44 @@ class ShapExplainer(BaseExplainer):
|
|
|
81
142
|
fname: float(flat_shap[i])
|
|
82
143
|
for i, fname in enumerate(self.feature_names)
|
|
83
144
|
}
|
|
145
|
+
|
|
146
|
+
# Get expected value
|
|
147
|
+
if isinstance(self.explainer.expected_value, (list, np.ndarray)):
|
|
148
|
+
expected_val = float(self.explainer.expected_value[label_index])
|
|
149
|
+
else:
|
|
150
|
+
expected_val = float(self.explainer.expected_value)
|
|
84
151
|
|
|
85
152
|
return Explanation(
|
|
86
153
|
explainer_name="SHAP",
|
|
87
154
|
target_class=label_name,
|
|
88
|
-
explanation_data={
|
|
155
|
+
explanation_data={
|
|
156
|
+
"feature_attributions": attributions,
|
|
157
|
+
"shap_values_raw": flat_shap.tolist(),
|
|
158
|
+
"expected_value": expected_val
|
|
159
|
+
},
|
|
160
|
+
feature_names=self.feature_names
|
|
89
161
|
)
|
|
162
|
+
|
|
163
|
+
def explain_batch(
|
|
164
|
+
self,
|
|
165
|
+
X: np.ndarray,
|
|
166
|
+
top_labels: int = 1
|
|
167
|
+
) -> List[Explanation]:
|
|
168
|
+
"""
|
|
169
|
+
Generate explanations for multiple instances.
|
|
170
|
+
|
|
171
|
+
Args:
|
|
172
|
+
X: 2D numpy array of instances
|
|
173
|
+
top_labels: Number of top labels to explain
|
|
174
|
+
|
|
175
|
+
Returns:
|
|
176
|
+
List of Explanation objects
|
|
177
|
+
"""
|
|
178
|
+
X = np.asarray(X)
|
|
179
|
+
if X.ndim == 1:
|
|
180
|
+
X = X.reshape(1, -1)
|
|
181
|
+
|
|
182
|
+
return [
|
|
183
|
+
self.explain(X[i], top_labels=top_labels)
|
|
184
|
+
for i in range(X.shape[0])
|
|
185
|
+
]
|
|
@@ -13,6 +13,7 @@ Explainers:
|
|
|
13
13
|
- SmoothGradExplainer: Noise-averaged gradients
|
|
14
14
|
- SaliencyExplainer: Basic gradient attribution
|
|
15
15
|
- TCAVExplainer: Concept-based explanations (TCAV)
|
|
16
|
+
- LRPExplainer: Layer-wise Relevance Propagation
|
|
16
17
|
"""
|
|
17
18
|
|
|
18
19
|
from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
|
|
@@ -21,6 +22,7 @@ from explainiverse.explainers.gradient.deeplift import DeepLIFTExplainer, DeepLI
|
|
|
21
22
|
from explainiverse.explainers.gradient.smoothgrad import SmoothGradExplainer
|
|
22
23
|
from explainiverse.explainers.gradient.saliency import SaliencyExplainer
|
|
23
24
|
from explainiverse.explainers.gradient.tcav import TCAVExplainer, ConceptActivationVector
|
|
25
|
+
from explainiverse.explainers.gradient.lrp import LRPExplainer
|
|
24
26
|
|
|
25
27
|
__all__ = [
|
|
26
28
|
"IntegratedGradientsExplainer",
|
|
@@ -31,4 +33,5 @@ __all__ = [
|
|
|
31
33
|
"SaliencyExplainer",
|
|
32
34
|
"TCAVExplainer",
|
|
33
35
|
"ConceptActivationVector",
|
|
36
|
+
"LRPExplainer",
|
|
34
37
|
]
|
|
@@ -30,7 +30,7 @@ Example:
|
|
|
30
30
|
"""
|
|
31
31
|
|
|
32
32
|
import numpy as np
|
|
33
|
-
from typing import List, Optional, Union, Callable
|
|
33
|
+
from typing import List, Optional, Union, Callable, Tuple
|
|
34
34
|
|
|
35
35
|
from explainiverse.core.explainer import BaseExplainer
|
|
36
36
|
from explainiverse.core.explanation import Explanation
|
|
@@ -44,23 +44,28 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
44
44
|
a baseline (default: zero vector) to the input. The integral is
|
|
45
45
|
approximated using the Riemann sum.
|
|
46
46
|
|
|
47
|
+
Supports both tabular data (1D/2D) and image data (3D/4D), preserving
|
|
48
|
+
the original input shape for proper gradient computation.
|
|
49
|
+
|
|
47
50
|
Attributes:
|
|
48
51
|
model: Model adapter with predict_with_gradients() method
|
|
49
|
-
feature_names: List of feature names
|
|
52
|
+
feature_names: List of feature names (for tabular data)
|
|
50
53
|
class_names: List of class names (for classification)
|
|
51
54
|
n_steps: Number of steps for integral approximation
|
|
52
55
|
baseline: Baseline input (default: zeros)
|
|
53
|
-
method: Integration method
|
|
56
|
+
method: Integration method
|
|
57
|
+
input_shape: Expected input shape (inferred or specified)
|
|
54
58
|
"""
|
|
55
59
|
|
|
56
60
|
def __init__(
|
|
57
61
|
self,
|
|
58
62
|
model,
|
|
59
|
-
feature_names: List[str],
|
|
63
|
+
feature_names: Optional[List[str]] = None,
|
|
60
64
|
class_names: Optional[List[str]] = None,
|
|
61
65
|
n_steps: int = 50,
|
|
62
|
-
baseline: Optional[np.ndarray] = None,
|
|
63
|
-
method: str = "riemann_middle"
|
|
66
|
+
baseline: Optional[Union[np.ndarray, str, Callable]] = None,
|
|
67
|
+
method: str = "riemann_middle",
|
|
68
|
+
input_shape: Optional[Tuple[int, ...]] = None
|
|
64
69
|
):
|
|
65
70
|
"""
|
|
66
71
|
Initialize the Integrated Gradients explainer.
|
|
@@ -68,17 +73,23 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
68
73
|
Args:
|
|
69
74
|
model: A model adapter with predict_with_gradients() method.
|
|
70
75
|
Use PyTorchAdapter for PyTorch models.
|
|
71
|
-
feature_names: List of input feature names.
|
|
76
|
+
feature_names: List of input feature names. Required for tabular
|
|
77
|
+
data to create named attributions.
|
|
72
78
|
class_names: List of class names (for classification tasks).
|
|
73
79
|
n_steps: Number of steps for approximating the integral.
|
|
74
80
|
More steps = more accurate but slower. Default: 50.
|
|
75
|
-
baseline: Baseline input for comparison
|
|
76
|
-
|
|
81
|
+
baseline: Baseline input for comparison:
|
|
82
|
+
- None: uses zeros
|
|
83
|
+
- "random": random baseline (useful for images)
|
|
84
|
+
- np.ndarray: specific baseline values
|
|
85
|
+
- Callable: function(instance) -> baseline
|
|
77
86
|
method: Integration method:
|
|
78
87
|
- "riemann_middle": Middle Riemann sum (default, most accurate)
|
|
79
88
|
- "riemann_left": Left Riemann sum
|
|
80
89
|
- "riemann_right": Right Riemann sum
|
|
81
90
|
- "riemann_trapezoid": Trapezoidal rule
|
|
91
|
+
input_shape: Expected shape of a single input (excluding batch dim).
|
|
92
|
+
If None, inferred from first explain() call.
|
|
82
93
|
"""
|
|
83
94
|
super().__init__(model)
|
|
84
95
|
|
|
@@ -89,28 +100,61 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
89
100
|
"Use PyTorchAdapter for PyTorch models."
|
|
90
101
|
)
|
|
91
102
|
|
|
92
|
-
self.feature_names = list(feature_names)
|
|
103
|
+
self.feature_names = list(feature_names) if feature_names else None
|
|
93
104
|
self.class_names = list(class_names) if class_names else None
|
|
94
105
|
self.n_steps = n_steps
|
|
95
106
|
self.baseline = baseline
|
|
96
107
|
self.method = method
|
|
108
|
+
self.input_shape = input_shape
|
|
109
|
+
|
|
110
|
+
def _infer_data_type(self, instance: np.ndarray) -> str:
|
|
111
|
+
"""
|
|
112
|
+
Infer whether input is tabular or image data.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
instance: Input instance (without batch dimension)
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
"tabular" for 1D data, "image" for 2D+ data
|
|
119
|
+
"""
|
|
120
|
+
if instance.ndim == 1:
|
|
121
|
+
return "tabular"
|
|
122
|
+
elif instance.ndim >= 2:
|
|
123
|
+
return "image"
|
|
124
|
+
else:
|
|
125
|
+
return "tabular"
|
|
97
126
|
|
|
98
127
|
def _get_baseline(self, instance: np.ndarray) -> np.ndarray:
|
|
99
|
-
"""
|
|
128
|
+
"""
|
|
129
|
+
Get the baseline for a given input shape.
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
instance: Input instance (preserves shape)
|
|
133
|
+
|
|
134
|
+
Returns:
|
|
135
|
+
Baseline array with same shape as instance
|
|
136
|
+
"""
|
|
100
137
|
if self.baseline is None:
|
|
101
138
|
# Default: zero baseline
|
|
102
139
|
return np.zeros_like(instance)
|
|
103
|
-
elif isinstance(self.baseline, str)
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
140
|
+
elif isinstance(self.baseline, str):
|
|
141
|
+
if self.baseline == "random":
|
|
142
|
+
# Random baseline (useful for images)
|
|
143
|
+
return np.random.uniform(
|
|
144
|
+
low=float(instance.min()),
|
|
145
|
+
high=float(instance.max()),
|
|
146
|
+
size=instance.shape
|
|
147
|
+
).astype(instance.dtype)
|
|
148
|
+
elif self.baseline == "mean":
|
|
149
|
+
# Mean value baseline
|
|
150
|
+
return np.full_like(instance, instance.mean())
|
|
151
|
+
else:
|
|
152
|
+
raise ValueError(f"Unknown baseline type: {self.baseline}")
|
|
110
153
|
elif callable(self.baseline):
|
|
111
|
-
|
|
154
|
+
result = self.baseline(instance)
|
|
155
|
+
return np.asarray(result).reshape(instance.shape)
|
|
112
156
|
else:
|
|
113
|
-
return np.
|
|
157
|
+
return np.asarray(self.baseline).reshape(instance.shape)
|
|
114
158
|
|
|
115
159
|
def _get_interpolation_alphas(self) -> np.ndarray:
|
|
116
160
|
"""Get interpolation points based on method."""
|
|
@@ -134,37 +178,67 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
134
178
|
"""
|
|
135
179
|
Compute integrated gradients for a single instance.
|
|
136
180
|
|
|
181
|
+
Preserves input shape throughout computation for proper gradient flow.
|
|
182
|
+
|
|
137
183
|
The integral is approximated as:
|
|
138
184
|
IG_i = (x_i - x'_i) * sum_{k=1}^{m} grad_i(x' + k/m * (x - x')) / m
|
|
139
185
|
|
|
140
186
|
where x is the input, x' is the baseline, and m is n_steps.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
instance: Input instance (any shape)
|
|
190
|
+
baseline: Baseline with same shape as instance
|
|
191
|
+
target_class: Target class for attribution
|
|
192
|
+
|
|
193
|
+
Returns:
|
|
194
|
+
Attributions with same shape as instance
|
|
141
195
|
"""
|
|
196
|
+
# Store original shape
|
|
197
|
+
original_shape = instance.shape
|
|
198
|
+
|
|
142
199
|
# Get interpolation points
|
|
143
200
|
alphas = self._get_interpolation_alphas()
|
|
144
201
|
|
|
145
202
|
# Compute path from baseline to input
|
|
146
|
-
# Shape: (n_steps, n_features)
|
|
147
203
|
delta = instance - baseline
|
|
148
|
-
interpolated_inputs = baseline + alphas[:, np.newaxis] * delta
|
|
149
204
|
|
|
150
|
-
#
|
|
205
|
+
# Collect gradients at each interpolation point
|
|
151
206
|
all_gradients = []
|
|
152
|
-
|
|
207
|
+
|
|
208
|
+
for alpha in alphas:
|
|
209
|
+
# Interpolated input: baseline + alpha * (input - baseline)
|
|
210
|
+
interp_input = baseline + alpha * delta
|
|
211
|
+
|
|
212
|
+
# Add batch dimension for model
|
|
213
|
+
if interp_input.ndim == len(original_shape):
|
|
214
|
+
interp_batch = interp_input[np.newaxis, ...]
|
|
215
|
+
else:
|
|
216
|
+
interp_batch = interp_input
|
|
217
|
+
|
|
218
|
+
# Get gradients
|
|
153
219
|
_, gradients = self.model.predict_with_gradients(
|
|
154
|
-
|
|
220
|
+
interp_batch,
|
|
155
221
|
target_class=target_class
|
|
156
222
|
)
|
|
157
|
-
|
|
223
|
+
|
|
224
|
+
# Remove batch dimension if present
|
|
225
|
+
if gradients.shape[0] == 1 and len(gradients.shape) > len(original_shape):
|
|
226
|
+
gradients = gradients[0]
|
|
227
|
+
|
|
228
|
+
all_gradients.append(gradients.reshape(original_shape))
|
|
158
229
|
|
|
159
|
-
all_gradients = np.array(all_gradients) # Shape: (n_steps,
|
|
230
|
+
all_gradients = np.array(all_gradients) # Shape: (n_steps, *original_shape)
|
|
160
231
|
|
|
161
232
|
# Approximate the integral
|
|
162
233
|
if self.method == "riemann_trapezoid":
|
|
163
|
-
# Trapezoidal rule
|
|
234
|
+
# Trapezoidal rule
|
|
164
235
|
weights = np.ones(self.n_steps + 1)
|
|
165
236
|
weights[0] = 0.5
|
|
166
237
|
weights[-1] = 0.5
|
|
167
|
-
|
|
238
|
+
# Expand weights for broadcasting
|
|
239
|
+
for _ in range(len(original_shape)):
|
|
240
|
+
weights = weights[:, np.newaxis]
|
|
241
|
+
avg_gradients = np.sum(all_gradients * weights, axis=0) / self.n_steps
|
|
168
242
|
else:
|
|
169
243
|
# Standard Riemann sum: average of gradients
|
|
170
244
|
avg_gradients = np.mean(all_gradients, axis=0)
|
|
@@ -185,7 +259,10 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
185
259
|
Generate Integrated Gradients explanation for an instance.
|
|
186
260
|
|
|
187
261
|
Args:
|
|
188
|
-
instance:
|
|
262
|
+
instance: Input instance. Can be:
|
|
263
|
+
- 1D array for tabular data
|
|
264
|
+
- 2D array for grayscale images
|
|
265
|
+
- 3D array for color images (C, H, W)
|
|
189
266
|
target_class: For classification, which class to explain.
|
|
190
267
|
If None, uses the predicted class.
|
|
191
268
|
baseline: Override the default baseline for this explanation.
|
|
@@ -196,66 +273,90 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
196
273
|
Returns:
|
|
197
274
|
Explanation object with feature attributions.
|
|
198
275
|
"""
|
|
199
|
-
instance = np.
|
|
276
|
+
instance = np.asarray(instance).astype(np.float32)
|
|
277
|
+
original_shape = instance.shape
|
|
200
278
|
|
|
201
|
-
#
|
|
279
|
+
# Infer data type
|
|
280
|
+
data_type = self._infer_data_type(instance)
|
|
281
|
+
|
|
282
|
+
# Get baseline (preserves shape)
|
|
202
283
|
if baseline is not None:
|
|
203
|
-
bl = np.
|
|
284
|
+
bl = np.asarray(baseline).astype(np.float32).reshape(original_shape)
|
|
204
285
|
else:
|
|
205
286
|
bl = self._get_baseline(instance)
|
|
206
287
|
|
|
207
288
|
# Determine target class if not specified
|
|
208
289
|
if target_class is None and self.class_names:
|
|
209
|
-
|
|
210
|
-
|
|
290
|
+
# Add batch dim for prediction
|
|
291
|
+
pred_input = instance[np.newaxis, ...] if instance.ndim == len(original_shape) else instance
|
|
292
|
+
predictions = self.model.predict(pred_input)
|
|
293
|
+
target_class = int(np.argmax(predictions[0]))
|
|
211
294
|
|
|
212
|
-
# Compute integrated gradients
|
|
295
|
+
# Compute integrated gradients (preserves shape)
|
|
213
296
|
ig_attributions = self._compute_integrated_gradients(
|
|
214
297
|
instance, bl, target_class
|
|
215
298
|
)
|
|
216
299
|
|
|
217
|
-
# Build
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
300
|
+
# Build explanation data
|
|
301
|
+
explanation_data = {
|
|
302
|
+
"attributions_raw": ig_attributions.tolist(),
|
|
303
|
+
"baseline": bl.tolist(),
|
|
304
|
+
"n_steps": self.n_steps,
|
|
305
|
+
"method": self.method,
|
|
306
|
+
"input_shape": list(original_shape),
|
|
307
|
+
"data_type": data_type
|
|
221
308
|
}
|
|
222
309
|
|
|
310
|
+
# For tabular data, create named attributions
|
|
311
|
+
if data_type == "tabular" and self.feature_names is not None:
|
|
312
|
+
flat_ig = ig_attributions.flatten()
|
|
313
|
+
if len(flat_ig) == len(self.feature_names):
|
|
314
|
+
attributions = {
|
|
315
|
+
fname: float(flat_ig[i])
|
|
316
|
+
for i, fname in enumerate(self.feature_names)
|
|
317
|
+
}
|
|
318
|
+
explanation_data["feature_attributions"] = attributions
|
|
319
|
+
elif data_type == "image":
|
|
320
|
+
# For images, store aggregated feature importance
|
|
321
|
+
explanation_data["attribution_map"] = ig_attributions
|
|
322
|
+
# Also store channel-aggregated saliency for visualization
|
|
323
|
+
if ig_attributions.ndim == 3: # (C, H, W)
|
|
324
|
+
explanation_data["saliency_map"] = np.abs(ig_attributions).sum(axis=0)
|
|
325
|
+
else:
|
|
326
|
+
explanation_data["saliency_map"] = np.abs(ig_attributions)
|
|
327
|
+
|
|
223
328
|
# Determine class name
|
|
224
329
|
if self.class_names and target_class is not None:
|
|
225
330
|
label_name = self.class_names[target_class]
|
|
226
331
|
else:
|
|
227
332
|
label_name = f"class_{target_class}" if target_class is not None else "output"
|
|
228
333
|
|
|
229
|
-
explanation_data = {
|
|
230
|
-
"feature_attributions": attributions,
|
|
231
|
-
"attributions_raw": ig_attributions.tolist(),
|
|
232
|
-
"baseline": bl.tolist(),
|
|
233
|
-
"n_steps": self.n_steps,
|
|
234
|
-
"method": self.method
|
|
235
|
-
}
|
|
236
|
-
|
|
237
334
|
# Optionally compute convergence delta
|
|
238
335
|
if return_convergence_delta:
|
|
239
336
|
# The sum of attributions should equal F(x) - F(baseline)
|
|
240
|
-
pred_input =
|
|
241
|
-
pred_baseline =
|
|
337
|
+
pred_input = instance[np.newaxis, ...]
|
|
338
|
+
pred_baseline = bl[np.newaxis, ...]
|
|
339
|
+
|
|
340
|
+
pred_input_val = self.model.predict(pred_input)
|
|
341
|
+
pred_baseline_val = self.model.predict(pred_baseline)
|
|
242
342
|
|
|
243
|
-
if target_class is not None:
|
|
244
|
-
pred_diff =
|
|
343
|
+
if target_class is not None and pred_input_val.shape[-1] > 1:
|
|
344
|
+
pred_diff = pred_input_val[0, target_class] - pred_baseline_val[0, target_class]
|
|
245
345
|
else:
|
|
246
|
-
pred_diff =
|
|
346
|
+
pred_diff = pred_input_val[0, 0] - pred_baseline_val[0, 0]
|
|
247
347
|
|
|
248
|
-
attribution_sum = np.sum(ig_attributions)
|
|
249
|
-
convergence_delta = abs(pred_diff - attribution_sum)
|
|
348
|
+
attribution_sum = float(np.sum(ig_attributions))
|
|
349
|
+
convergence_delta = abs(float(pred_diff) - attribution_sum)
|
|
250
350
|
|
|
251
|
-
explanation_data["convergence_delta"] =
|
|
351
|
+
explanation_data["convergence_delta"] = convergence_delta
|
|
252
352
|
explanation_data["prediction_difference"] = float(pred_diff)
|
|
253
|
-
explanation_data["attribution_sum"] =
|
|
353
|
+
explanation_data["attribution_sum"] = attribution_sum
|
|
254
354
|
|
|
255
355
|
return Explanation(
|
|
256
356
|
explainer_name="IntegratedGradients",
|
|
257
357
|
target_class=label_name,
|
|
258
|
-
explanation_data=explanation_data
|
|
358
|
+
explanation_data=explanation_data,
|
|
359
|
+
feature_names=self.feature_names
|
|
259
360
|
)
|
|
260
361
|
|
|
261
362
|
def explain_batch(
|
|
@@ -266,20 +367,21 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
266
367
|
"""
|
|
267
368
|
Generate explanations for multiple instances.
|
|
268
369
|
|
|
269
|
-
Note: This
|
|
270
|
-
|
|
271
|
-
the batched gradient computation in a custom implementation.
|
|
370
|
+
Note: This processes instances sequentially. For large batches,
|
|
371
|
+
consider implementing batched gradient computation.
|
|
272
372
|
|
|
273
373
|
Args:
|
|
274
|
-
X:
|
|
374
|
+
X: Array of instances. First dimension is batch.
|
|
275
375
|
target_class: Target class for all instances.
|
|
276
376
|
|
|
277
377
|
Returns:
|
|
278
378
|
List of Explanation objects.
|
|
279
379
|
"""
|
|
280
|
-
X = np.
|
|
380
|
+
X = np.asarray(X)
|
|
381
|
+
|
|
382
|
+
# Handle single instance passed as array
|
|
281
383
|
if X.ndim == 1:
|
|
282
|
-
|
|
384
|
+
return [self.explain(X, target_class=target_class)]
|
|
283
385
|
|
|
284
386
|
return [
|
|
285
387
|
self.explain(X[i], target_class=target_class)
|
|
@@ -308,12 +410,13 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
308
410
|
Returns:
|
|
309
411
|
Explanation with averaged attributions.
|
|
310
412
|
"""
|
|
311
|
-
instance = np.
|
|
413
|
+
instance = np.asarray(instance).astype(np.float32)
|
|
414
|
+
original_shape = instance.shape
|
|
312
415
|
|
|
313
416
|
all_attributions = []
|
|
314
417
|
for _ in range(n_samples):
|
|
315
418
|
# Create noisy baseline
|
|
316
|
-
noise = np.random.normal(0, noise_scale,
|
|
419
|
+
noise = np.random.normal(0, noise_scale, original_shape).astype(np.float32)
|
|
317
420
|
noisy_baseline = noise # Noise around zero
|
|
318
421
|
|
|
319
422
|
ig = self._compute_integrated_gradients(
|
|
@@ -325,11 +428,26 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
325
428
|
avg_attributions = np.mean(all_attributions, axis=0)
|
|
326
429
|
std_attributions = np.std(all_attributions, axis=0)
|
|
327
430
|
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
431
|
+
# Build explanation data
|
|
432
|
+
data_type = self._infer_data_type(instance)
|
|
433
|
+
explanation_data = {
|
|
434
|
+
"attributions_raw": avg_attributions.tolist(),
|
|
435
|
+
"attributions_std": std_attributions.tolist(),
|
|
436
|
+
"n_samples": n_samples,
|
|
437
|
+
"noise_scale": noise_scale,
|
|
438
|
+
"data_type": data_type
|
|
331
439
|
}
|
|
332
440
|
|
|
441
|
+
# For tabular data, create named attributions
|
|
442
|
+
if data_type == "tabular" and self.feature_names is not None:
|
|
443
|
+
flat_avg = avg_attributions.flatten()
|
|
444
|
+
if len(flat_avg) == len(self.feature_names):
|
|
445
|
+
attributions = {
|
|
446
|
+
fname: float(flat_avg[i])
|
|
447
|
+
for i, fname in enumerate(self.feature_names)
|
|
448
|
+
}
|
|
449
|
+
explanation_data["feature_attributions"] = attributions
|
|
450
|
+
|
|
333
451
|
if self.class_names and target_class is not None:
|
|
334
452
|
label_name = self.class_names[target_class]
|
|
335
453
|
else:
|
|
@@ -338,11 +456,6 @@ class IntegratedGradientsExplainer(BaseExplainer):
|
|
|
338
456
|
return Explanation(
|
|
339
457
|
explainer_name="IntegratedGradients_Smooth",
|
|
340
458
|
target_class=label_name,
|
|
341
|
-
explanation_data=
|
|
342
|
-
|
|
343
|
-
"attributions_raw": avg_attributions.tolist(),
|
|
344
|
-
"attributions_std": std_attributions.tolist(),
|
|
345
|
-
"n_samples": n_samples,
|
|
346
|
-
"noise_scale": noise_scale
|
|
347
|
-
}
|
|
459
|
+
explanation_data=explanation_data,
|
|
460
|
+
feature_names=self.feature_names
|
|
348
461
|
)
|