explainiverse 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/__init__.py +5 -4
- explainiverse/core/registry.py +18 -0
- explainiverse/explainers/gradient/__init__.py +3 -0
- explainiverse/explainers/gradient/lrp.py +1211 -0
- {explainiverse-0.7.1.dist-info → explainiverse-0.8.1.dist-info}/METADATA +76 -13
- {explainiverse-0.7.1.dist-info → explainiverse-0.8.1.dist-info}/RECORD +8 -7
- {explainiverse-0.7.1.dist-info → explainiverse-0.8.1.dist-info}/LICENSE +0 -0
- {explainiverse-0.7.1.dist-info → explainiverse-0.8.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1211 @@
# src/explainiverse/explainers/gradient/lrp.py
"""
Layer-wise Relevance Propagation (LRP) - Decomposition-based Attribution.

LRP decomposes network predictions back to input features using a conservation
principle. Unlike gradient-based methods, LRP propagates relevance scores
layer-by-layer through the network using specific propagation rules.

Key Properties:
    - Conservation: Sum of relevances at each layer equals the output
    - Layer-wise decomposition: Relevance flows backward through layers
    - Multiple rules: Different rules for different layer types and use cases

Supported Layer Types:
    - Linear (fully connected)
    - Conv2d (convolutional)
    - BatchNorm1d, BatchNorm2d
    - ReLU, LeakyReLU, ELU, Tanh, Sigmoid
    - MaxPool2d, AvgPool2d
    - Flatten, Dropout (passthrough)

Propagation Rules:
    - LRP-0: Basic rule (no stabilization) - theoretical baseline
    - LRP-ε (epsilon): Adds small constant for numerical stability (recommended)
    - LRP-γ (gamma): Enhances positive contributions - good for image classification
    - LRP-αβ (alpha-beta): Separates positive/negative contributions - fine control
    - LRP-z⁺ (z-plus): Only considers positive weights - often used for input layers
    - Composite: Different rules for different layers

Mathematical Formulation:
    For layer l with input a and output z = Wa + b:

    LRP-0:  R_j = Σ_k (a_j * w_jk / z_k) * R_k
    LRP-ε:  R_j = Σ_k (a_j * w_jk / (z_k + ε*sign(z_k))) * R_k
    LRP-γ:  R_j = Σ_k (a_j * (w_jk + γ*w_jk⁺) / (z_k + γ*z_k⁺)) * R_k
    LRP-αβ: R_j = Σ_k (α * (a_j * w_jk)⁺ / z_k⁺ - β * (a_j * w_jk)⁻ / z_k⁻) * R_k
    LRP-z⁺: R_j = Σ_k (a_j * w_jk⁺ / Σ_i a_i * w_ik⁺) * R_k

Reference:
    Bach, S., Binder, A., Montavon, G., Klauschen, F., Müller, K. R., & Samek, W. (2015).
    On Pixel-wise Explanations for Non-Linear Classifier Decisions by Layer-wise
    Relevance Propagation. PLOS ONE.
    https://doi.org/10.1371/journal.pone.0130140

    Montavon, G., Binder, A., Lapuschkin, S., Samek, W., & Müller, K. R. (2019).
    Layer-wise Relevance Propagation: An Overview. Explainable AI: Interpreting,
    Explaining and Visualizing Deep Learning. Springer.

Example:
    from explainiverse.explainers.gradient import LRPExplainer
    from explainiverse.adapters import PyTorchAdapter

    adapter = PyTorchAdapter(model, task="classification")

    explainer = LRPExplainer(
        model=adapter,
        feature_names=feature_names,
        rule="epsilon",
        epsilon=1e-6
    )

    explanation = explainer.explain(instance)
"""

import numpy as np
from typing import List, Optional, Dict, Any, Tuple, Union
from collections import OrderedDict

from explainiverse.core.explainer import BaseExplainer
from explainiverse.core.explanation import Explanation


# Check if PyTorch is available
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    torch = None
    nn = None
    F = None


# Valid LRP rules
VALID_RULES = ["epsilon", "gamma", "alpha_beta", "z_plus", "composite"]

# Layer types that require special LRP handling
WEIGHTED_LAYERS = (nn.Linear, nn.Conv2d) if TORCH_AVAILABLE else ()
NORMALIZATION_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d) if TORCH_AVAILABLE else ()
ACTIVATION_LAYERS = (nn.ReLU, nn.LeakyReLU, nn.ELU, nn.Tanh, nn.Sigmoid, nn.GELU) if TORCH_AVAILABLE else ()
POOLING_LAYERS = (nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d) if TORCH_AVAILABLE else ()
PASSTHROUGH_LAYERS = (nn.Dropout, nn.Dropout2d, nn.Flatten) if TORCH_AVAILABLE else ()


class LRPExplainer(BaseExplainer):
    """
    Layer-wise Relevance Propagation (LRP) explainer for neural networks.

    LRP decomposes the network output into relevance scores for each input
    feature by propagating relevance backward through the network layers.
    The key property is conservation: the sum of relevances at each layer
    equals the relevance at the layer above.

    Supports:
        - Fully connected networks (Linear + activations)
        - Convolutional networks (Conv2d + BatchNorm + pooling)
        - Mixed architectures

    Attributes:
        model: Model adapter (must be PyTorchAdapter)
        feature_names: List of feature names
        class_names: List of class names (for classification)
        rule: Propagation rule ("epsilon", "gamma", "alpha_beta", "z_plus", "composite")
        epsilon: Stabilization constant for epsilon rule
        gamma: Enhancement factor for gamma rule
        alpha: Positive contribution weight for alpha-beta rule
        beta: Negative contribution weight for alpha-beta rule

    Example:
        >>> explainer = LRPExplainer(adapter, feature_names, rule="epsilon")
        >>> explanation = explainer.explain(instance)
        >>> print(explanation.explanation_data["feature_attributions"])
    """

    def __init__(
        self,
        model,
        feature_names: List[str],
        class_names: Optional[List[str]] = None,
        rule: str = "epsilon",
        epsilon: float = 1e-6,
        gamma: float = 0.25,
        alpha: float = 2.0,
        beta: float = 1.0
    ):
        """
        Initialize the LRP explainer.

        Args:
            model: A PyTorchAdapter wrapping the model to explain.
            feature_names: List of input feature names.
            class_names: List of class names (for classification tasks).
            rule: Propagation rule to use:
                - "epsilon": LRP-ε with stabilization (default, recommended)
                - "gamma": LRP-γ enhancing positive contributions
                - "alpha_beta": LRP-αβ separating pos/neg contributions
                - "z_plus": LRP-z⁺ using only positive weights
                - "composite": Different rules for different layers
            epsilon: Small constant for numerical stability in epsilon rule.
                Default: 1e-6
            gamma: Factor to enhance positive contributions in gamma rule.
                Default: 0.25
            alpha: Weight for positive contributions in alpha-beta rule.
                Must satisfy alpha - beta = 1. Default: 2.0
            beta: Weight for negative contributions in alpha-beta rule.
                Must satisfy alpha - beta = 1. Default: 1.0

        Raises:
            TypeError: If model is not a PyTorchAdapter.
            ValueError: If rule is invalid or alpha-beta constraint violated.
        """
        if not TORCH_AVAILABLE:
            raise ImportError(
                "PyTorch is required for LRP. Install with: pip install torch"
            )

        super().__init__(model)

        # Validate model is PyTorchAdapter
        if not hasattr(model, 'model') or not isinstance(model.model, nn.Module):
            raise TypeError(
                "LRP requires a PyTorchAdapter wrapping a PyTorch model. "
                "Use: PyTorchAdapter(your_model, task='classification')"
            )

        # Validate rule
        if rule not in VALID_RULES:
            raise ValueError(
                f"Invalid rule: '{rule}'. Must be one of: {VALID_RULES}"
            )

        # Validate alpha-beta constraint
        if rule == "alpha_beta":
            if not np.isclose(alpha - beta, 1.0):
                raise ValueError(
                    f"For alpha-beta rule, alpha - beta must equal 1. "
                    f"Got alpha={alpha}, beta={beta}, difference={alpha - beta}"
                )

        self.feature_names = list(feature_names)
        self.class_names = list(class_names) if class_names else None
        self.rule = rule
        self.epsilon = epsilon
        self.gamma = gamma
        self.alpha = alpha
        self.beta = beta

        # For composite rules
        self._layer_rules: Optional[Dict[int, str]] = None

        # Cache for layer information
        self._layers_info: Optional[List[Dict[str, Any]]] = None

    def _get_pytorch_model(self) -> nn.Module:
        """Get the underlying PyTorch model."""
        return self.model.model

    def _is_cnn_model(self) -> bool:
        """Check if the model's first weighted layer is Conv2d."""
        model = self._get_pytorch_model()
        for module in model.modules():
            if isinstance(module, nn.Conv2d):
                return True
            if isinstance(module, nn.Linear):
                return False
        return False

    def _prepare_input_tensor(self, instance: np.ndarray) -> torch.Tensor:
        """
        Prepare input tensor with correct shape for the model.

        For CNN models, preserves the spatial dimensions.
        For MLP models, flattens to 2D.

        Args:
            instance: Input array (1D for tabular, 3D for images)

        Returns:
            Tensor with batch dimension added and correct shape for model
        """
        instance = np.array(instance).astype(np.float32)
        original_shape = instance.shape

        model = self._get_pytorch_model()

        # Find first weighted layer to determine input type
        first_layer = None
        for module in model.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                first_layer = module
                break

        if isinstance(first_layer, nn.Conv2d):
            # CNN model - need 4D input (batch, channels, height, width)
            in_channels = first_layer.in_channels

            if len(original_shape) >= 3:
                # Already (C, H, W) format - just add batch dimension
                x = torch.tensor(instance).unsqueeze(0)
            elif len(original_shape) == 2:
                # (H, W) - assume single channel, add channel and batch dimensions
                x = torch.tensor(instance).unsqueeze(0).unsqueeze(0)
            else:
                # Flattened - try to infer spatial dimensions
                n_features = instance.size
                if n_features % in_channels == 0:
                    spatial_size = int(np.sqrt(n_features // in_channels))
                    if spatial_size * spatial_size * in_channels == n_features:
                        x = torch.tensor(instance.flatten()).reshape(
                            1, in_channels, spatial_size, spatial_size
                        )
                    else:
                        raise ValueError(
                            f"Cannot infer spatial dimensions for {n_features} features "
                            f"with {in_channels} channels"
                        )
                else:
                    raise ValueError(
                        f"Number of features ({n_features}) not divisible by "
                        f"input channels ({in_channels})"
                    )
        else:
            # MLP model - need 2D input (batch, features)
            x = torch.tensor(instance.flatten()).reshape(1, -1)

        return x.float()

    def _get_rule_for_layer(self, layer_idx: int, layer_type: str) -> str:
        """
        Get the propagation rule for a specific layer.

        Args:
            layer_idx: Index of the layer
            layer_type: Type of the layer (e.g., "Linear", "Conv2d")

        Returns:
            Rule name to use for this layer
        """
        if self.rule != "composite":
            return self.rule

        # Composite rule: check layer-specific rules
        if self._layer_rules and layer_idx in self._layer_rules:
            return self._layer_rules[layer_idx]

        # Default fallback for composite
        return "epsilon"

    def set_composite_rule(self, layer_rules: Dict[int, str]) -> "LRPExplainer":
        """
        Set layer-specific rules for composite LRP.

        This allows using different propagation rules for different layers,
        which is often beneficial. A common practice is:
        - z_plus for input/early layers (focuses on what's present)
        - epsilon for middle layers (balanced attribution)
        - epsilon or zero for final layers

        Args:
            layer_rules: Dictionary mapping layer indices to rule names.
                Layers not in this dict use "epsilon" by default.

        Returns:
            Self for method chaining.

        Example:
            >>> explainer.set_composite_rule({
            ...     0: "z_plus",   # First layer
            ...     2: "epsilon",  # Middle layer
            ...     4: "epsilon"   # Final layer
            ... })
        """
        # Validate rules
        for idx, rule in layer_rules.items():
            if rule not in VALID_RULES and rule != "composite":
                raise ValueError(f"Invalid rule '{rule}' for layer {idx}")

        self._layer_rules = layer_rules
        return self

    # =========================================================================
    # Linear Layer LRP Rules
    # =========================================================================

    def _lrp_linear_epsilon(
        self,
        layer: nn.Linear,
        activation: torch.Tensor,
        relevance: torch.Tensor,
        epsilon: float
    ) -> torch.Tensor:
        """
        LRP-epsilon rule for linear layers.

        R_j = Σ_k (a_j * w_jk / (z_k + ε*sign(z_k))) * R_k
        """
        # Forward pass to get z
        z = torch.mm(activation, layer.weight.t())
        if layer.bias is not None:
            z = z + layer.bias

        # Stabilize: z + epsilon * sign(z)
        z_stabilized = z + epsilon * torch.sign(z)
        z_stabilized = torch.where(
            torch.abs(z_stabilized) < epsilon,
            torch.full_like(z_stabilized, epsilon),
            z_stabilized
        )

        # Compute relevance contribution: (R / z_stabilized) @ W
        s = relevance / z_stabilized
        c = torch.mm(s, layer.weight)

        return activation * c

    def _lrp_linear_gamma(
        self,
        layer: nn.Linear,
        activation: torch.Tensor,
        relevance: torch.Tensor,
        gamma: float
    ) -> torch.Tensor:
        """
        LRP-gamma rule for linear layers.
        Enhances positive contributions for sharper attributions.
        """
        w_plus = torch.clamp(layer.weight, min=0)
        w_modified = layer.weight + gamma * w_plus

        z = torch.mm(activation, w_modified.t())
        if layer.bias is not None:
            b_plus = torch.clamp(layer.bias, min=0)
            z = z + layer.bias + gamma * b_plus

        z_stabilized = z + self.epsilon * torch.sign(z)
        z_stabilized = torch.where(
            torch.abs(z_stabilized) < self.epsilon,
            torch.full_like(z_stabilized, self.epsilon),
            z_stabilized
        )

        s = relevance / z_stabilized
        c = torch.mm(s, w_modified)

        return activation * c

    def _lrp_linear_alpha_beta(
        self,
        layer: nn.Linear,
        activation: torch.Tensor,
        relevance: torch.Tensor,
        alpha: float,
        beta: float
    ) -> torch.Tensor:
        """
        LRP-alpha-beta rule for linear layers.
        Separates positive and negative contributions.
        """
        w_plus = torch.clamp(layer.weight, min=0)
        w_minus = torch.clamp(layer.weight, max=0)
        a_plus = torch.clamp(activation, min=0)

        z_plus = torch.mm(a_plus, w_plus.t())
        if layer.bias is not None:
            z_plus = z_plus + torch.clamp(layer.bias, min=0)

        z_minus = torch.mm(a_plus, w_minus.t())
        if layer.bias is not None:
            z_minus = z_minus + torch.clamp(layer.bias, max=0)

        z_plus_stable = z_plus + self.epsilon
        z_minus_stable = z_minus - self.epsilon
        z_minus_stable = torch.where(
            torch.abs(z_minus_stable) < self.epsilon,
            torch.full_like(z_minus_stable, -self.epsilon),
            z_minus_stable
        )

        s_plus = relevance / z_plus_stable
        s_minus = relevance / z_minus_stable

        c_plus = torch.mm(s_plus, w_plus)
        c_minus = torch.mm(s_minus, w_minus)

        return alpha * a_plus * c_plus - beta * a_plus * c_minus

    def _lrp_linear_z_plus(
        self,
        layer: nn.Linear,
        activation: torch.Tensor,
        relevance: torch.Tensor
    ) -> torch.Tensor:
        """
        LRP-z+ rule for linear layers.
        Only considers positive weights.
        """
        w_plus = torch.clamp(layer.weight, min=0)
        a_plus = torch.clamp(activation, min=0)

        z_plus = torch.mm(a_plus, w_plus.t())
        if layer.bias is not None:
            z_plus = z_plus + torch.clamp(layer.bias, min=0)

        z_plus_stable = z_plus + self.epsilon

        s = relevance / z_plus_stable
        c = torch.mm(s, w_plus)

        return a_plus * c

    def _propagate_linear(
        self,
        layer: nn.Linear,
        activation: torch.Tensor,
        relevance: torch.Tensor,
        rule: str
    ) -> torch.Tensor:
        """Propagate relevance through a linear layer."""
        if rule == "epsilon":
            return self._lrp_linear_epsilon(layer, activation, relevance, self.epsilon)
        elif rule == "gamma":
            return self._lrp_linear_gamma(layer, activation, relevance, self.gamma)
        elif rule == "alpha_beta":
            return self._lrp_linear_alpha_beta(layer, activation, relevance, self.alpha, self.beta)
        elif rule == "z_plus":
            return self._lrp_linear_z_plus(layer, activation, relevance)
        else:
            return self._lrp_linear_epsilon(layer, activation, relevance, self.epsilon)

    # =========================================================================
    # Conv2d Layer LRP Rules
    # =========================================================================

    def _lrp_conv2d_epsilon(
        self,
        layer: nn.Conv2d,
        activation: torch.Tensor,
        relevance: torch.Tensor,
        epsilon: float
    ) -> torch.Tensor:
        """
        LRP-epsilon rule for Conv2d layers.
        Uses convolution transpose for backward relevance propagation.
        """
        # Forward pass to get z
        z = F.conv2d(
            activation,
            layer.weight,
            bias=layer.bias,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            groups=layer.groups
        )

        # Stabilize
        z_stabilized = z + epsilon * torch.sign(z)
        z_stabilized = torch.where(
            torch.abs(z_stabilized) < epsilon,
            torch.full_like(z_stabilized, epsilon),
            z_stabilized
        )

        # Compute s = R / z
        s = relevance / z_stabilized

        # Backward pass using conv_transpose2d
        c = F.conv_transpose2d(
            s,
            layer.weight,
            bias=None,
            stride=layer.stride,
            padding=layer.padding,
            output_padding=0,
            groups=layer.groups,
            dilation=layer.dilation
        )

        # Handle output size mismatch
        if c.shape != activation.shape:
            # Pad or crop to match activation shape
            diff_h = activation.shape[2] - c.shape[2]
            diff_w = activation.shape[3] - c.shape[3]
            if diff_h > 0 or diff_w > 0:
                c = F.pad(c, [0, max(0, diff_w), 0, max(0, diff_h)])
            if diff_h < 0 or diff_w < 0:
                c = c[:, :, :activation.shape[2], :activation.shape[3]]

        return activation * c

    def _lrp_conv2d_gamma(
        self,
        layer: nn.Conv2d,
        activation: torch.Tensor,
        relevance: torch.Tensor,
        gamma: float
    ) -> torch.Tensor:
        """LRP-gamma rule for Conv2d layers."""
        w_plus = torch.clamp(layer.weight, min=0)
        w_modified = layer.weight + gamma * w_plus

        z = F.conv2d(
            activation,
            w_modified,
            bias=layer.bias,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            groups=layer.groups
        )

        if layer.bias is not None:
            b_plus = torch.clamp(layer.bias, min=0)
            # Bias is already added in conv2d, add gamma * b_plus
            z = z + gamma * b_plus.view(1, -1, 1, 1)

        z_stabilized = z + self.epsilon * torch.sign(z)
        z_stabilized = torch.where(
            torch.abs(z_stabilized) < self.epsilon,
            torch.full_like(z_stabilized, self.epsilon),
            z_stabilized
        )

        s = relevance / z_stabilized

        c = F.conv_transpose2d(
            s,
            w_modified,
            bias=None,
            stride=layer.stride,
            padding=layer.padding,
            output_padding=0,
            groups=layer.groups,
            dilation=layer.dilation
        )

        if c.shape != activation.shape:
            diff_h = activation.shape[2] - c.shape[2]
            diff_w = activation.shape[3] - c.shape[3]
            if diff_h > 0 or diff_w > 0:
                c = F.pad(c, [0, max(0, diff_w), 0, max(0, diff_h)])
            if diff_h < 0 or diff_w < 0:
                c = c[:, :, :activation.shape[2], :activation.shape[3]]

        return activation * c

    def _lrp_conv2d_z_plus(
        self,
        layer: nn.Conv2d,
        activation: torch.Tensor,
        relevance: torch.Tensor
    ) -> torch.Tensor:
        """LRP-z+ rule for Conv2d layers."""
        w_plus = torch.clamp(layer.weight, min=0)
        a_plus = torch.clamp(activation, min=0)

        z_plus = F.conv2d(
            a_plus,
            w_plus,
            bias=None,  # Ignore bias for z+
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            groups=layer.groups
        )

        if layer.bias is not None:
            z_plus = z_plus + torch.clamp(layer.bias, min=0).view(1, -1, 1, 1)

        z_plus_stable = z_plus + self.epsilon

        s = relevance / z_plus_stable

        c = F.conv_transpose2d(
            s,
            w_plus,
            bias=None,
            stride=layer.stride,
            padding=layer.padding,
            output_padding=0,
            groups=layer.groups,
            dilation=layer.dilation
        )

        if c.shape != a_plus.shape:
            diff_h = a_plus.shape[2] - c.shape[2]
            diff_w = a_plus.shape[3] - c.shape[3]
            if diff_h > 0 or diff_w > 0:
                c = F.pad(c, [0, max(0, diff_w), 0, max(0, diff_h)])
            if diff_h < 0 or diff_w < 0:
                c = c[:, :, :a_plus.shape[2], :a_plus.shape[3]]

        return a_plus * c

    def _propagate_conv2d(
        self,
        layer: nn.Conv2d,
        activation: torch.Tensor,
        relevance: torch.Tensor,
        rule: str
    ) -> torch.Tensor:
        """Propagate relevance through a Conv2d layer."""
        if rule == "epsilon":
            return self._lrp_conv2d_epsilon(layer, activation, relevance, self.epsilon)
        elif rule == "gamma":
            return self._lrp_conv2d_gamma(layer, activation, relevance, self.gamma)
        elif rule == "z_plus":
            return self._lrp_conv2d_z_plus(layer, activation, relevance)
        elif rule == "alpha_beta":
            # Alpha-beta for conv is complex, fall back to epsilon
            return self._lrp_conv2d_epsilon(layer, activation, relevance, self.epsilon)
        else:
            return self._lrp_conv2d_epsilon(layer, activation, relevance, self.epsilon)

    # =========================================================================
    # BatchNorm Layer LRP Rules
    # =========================================================================

    def _propagate_batchnorm(
        self,
        layer: Union[nn.BatchNorm1d, nn.BatchNorm2d],
        activation: torch.Tensor,
        relevance: torch.Tensor
    ) -> torch.Tensor:
        """
        Propagate relevance through BatchNorm layer.

        BatchNorm is an affine transformation: y = gamma * (x - mean) / std + beta
        We treat it as a linear scaling and propagate relevance proportionally.
        """
        # Get BatchNorm parameters
        if layer.running_mean is None or layer.running_var is None:
            # If no running stats, pass through
            return relevance

        mean = layer.running_mean
        var = layer.running_var
        eps = layer.eps

        # Compute the effective scale factor
        std = torch.sqrt(var + eps)

        if layer.weight is not None:
            scale = layer.weight / std
        else:
            scale = 1.0 / std

        # Reshape scale for broadcasting
        if isinstance(layer, nn.BatchNorm2d):
            scale = scale.view(1, -1, 1, 1)
        else:
            scale = scale.view(1, -1)

        # Relevance propagation: R_input = R_output (scaled back)
        # Since BN is essentially a rescaling, we redistribute proportionally
        return relevance / (scale + self.epsilon * torch.sign(scale))

    # =========================================================================
    # Activation Layer LRP Rules
    # =========================================================================

    def _propagate_activation(
        self,
        layer: nn.Module,
        activation: torch.Tensor,
        relevance: torch.Tensor
    ) -> torch.Tensor:
        """
        Propagate relevance through activation layers (ReLU, etc.).

        For element-wise activations, relevance passes through unchanged
        to locations where the activation was positive.
        """
        if isinstance(layer, nn.ReLU):
            # ReLU is handled implicitly: the stored activations are already
            # post-ReLU, so clipped units carry zero relevance and the
            # remaining relevance passes through unchanged.
            return relevance
        elif isinstance(layer, (nn.LeakyReLU, nn.ELU)):
            # For leaky activations, relevance passes through
            return relevance
        elif isinstance(layer, (nn.Tanh, nn.Sigmoid)):
            # For bounded activations, relevance passes through
            return relevance
        else:
            # Default: pass through
            return relevance

    # =========================================================================
    # Pooling Layer LRP Rules
    # =========================================================================

    def _propagate_maxpool2d(
        self,
        layer: nn.MaxPool2d,
        activation: torch.Tensor,
        relevance: torch.Tensor
    ) -> torch.Tensor:
        """
        Propagate relevance through MaxPool2d.

        Relevance is distributed to the max locations (winner-take-all).
        """
        # Forward pass to get indices
        _, indices = F.max_pool2d(
            activation,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            return_indices=True,
            ceil_mode=layer.ceil_mode
        )

        # Unpool: place relevance at max locations
        unpooled = F.max_unpool2d(
            relevance,
            indices,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            output_size=activation.shape
        )

        return unpooled

    def _propagate_avgpool2d(
        self,
        layer: Union[nn.AvgPool2d, nn.AdaptiveAvgPool2d],
        activation: torch.Tensor,
        relevance: torch.Tensor
    ) -> torch.Tensor:
        """
        Propagate relevance through AvgPool2d.

        Relevance is distributed uniformly across pooling regions.
        """
        if isinstance(layer, nn.AdaptiveAvgPool2d):
            # For adaptive pooling, upsample relevance to input size
            return F.interpolate(
                relevance,
                size=activation.shape[2:],
                mode='nearest'
            )
        else:
            # For regular avg pooling, redistribute relevance to each pooling
            # region via nearest-neighbor upsampling to the input size
            kernel_size = layer.kernel_size if isinstance(layer.kernel_size, tuple) else (layer.kernel_size, layer.kernel_size)

            upsampled = F.interpolate(
                relevance,
                size=activation.shape[2:],
                mode='nearest'
            )

            return upsampled

    # =========================================================================
    # Main LRP Computation
    # =========================================================================

    def _compute_lrp(
        self,
        instance: np.ndarray,
        target_class: Optional[int] = None,
        return_layer_relevances: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, Dict[str, np.ndarray]]]:
        """
        Compute LRP attributions for a single instance.

        Args:
            instance: Input instance (1D or multi-dimensional array)
            target_class: Target class for relevance initialization
            return_layer_relevances: If True, also return relevances at each layer

        Returns:
            If return_layer_relevances is False:
                Array of attribution scores for input features
            If return_layer_relevances is True:
                Tuple of (input_attributions, layer_relevances_dict)
        """
        model = self._get_pytorch_model()
        model.eval()

        # Prepare input with correct shape for model type (CNN vs MLP)
        x = self._prepare_input_tensor(instance)
        x.requires_grad_(False)

        # =====================================================================
        # Forward pass: collect activations at each layer
        # =====================================================================
        activations = OrderedDict()
        activations["input"] = x.clone()

        layer_list = []  # List of (idx, name, layer, input_activation)

        current = x

        # Handle Sequential models
        if isinstance(model, nn.Sequential):
            for idx, (name, layer) in enumerate(model.named_children()):
                layer_list.append((idx, name, layer, current.clone()))
                current = layer(current)
                activations[f"layer_{idx}_{name}"] = current.clone()
        else:
            # For non-Sequential models, use hooks
            hooks = []
            layer_data = OrderedDict()

            def make_hook(name):
                def hook(module, input, output):
                    inp = input[0] if isinstance(input, tuple) else input
                    layer_data[name] = {
                        "input": inp.clone().detach(),
                        "output": output.clone().detach() if isinstance(output, torch.Tensor) else output
                    }
                return hook

            # Register hooks on relevant layers
            idx = 0
            for name, module in model.named_modules():
                if isinstance(module, (*WEIGHTED_LAYERS, *NORMALIZATION_LAYERS, *ACTIVATION_LAYERS, *POOLING_LAYERS)):
                    hooks.append(module.register_forward_hook(make_hook(f"{idx}_{name}")))
                    idx += 1

            # Forward pass
            current = model(x)

            # Remove hooks
            for h in hooks:
                h.remove()

            # Build layer list from collected data
            for name, data in layer_data.items():
                idx_str, layer_name = name.split("_", 1)
                # Get the actual module
                module = dict(model.named_modules()).get(layer_name)
                if module is not None:
                    layer_list.append((int(idx_str), layer_name, module, data["input"]))

        output = current

        # =====================================================================
        # Initialize relevance at output layer
        # =====================================================================
        if target_class is not None:
            relevance = torch.zeros_like(output)
            relevance[0, target_class] = output[0, target_class]
        else:
            relevance = output.clone()

        # =====================================================================
        # Backward pass: propagate relevance through layers
        # =====================================================================
        layer_relevances = OrderedDict()
        layer_relevances["output"] = relevance.detach().cpu().numpy().flatten()

        # Reverse through layers
        for idx, name, layer, activation in reversed(layer_list):
            rule = self._get_rule_for_layer(idx, type(layer).__name__)

            # Propagate based on layer type
            if isinstance(layer, nn.Linear):
                # Flatten activation if needed
                if activation.dim() > 2:
                    activation = activation.flatten(1)
                if relevance.dim() > 2:
                    relevance = relevance.flatten(1)
                relevance = self._propagate_linear(layer, activation, relevance, rule)

            elif isinstance(layer, nn.Conv2d):
                relevance = self._propagate_conv2d(layer, activation, relevance, rule)

            elif isinstance(layer, NORMALIZATION_LAYERS):
                relevance = self._propagate_batchnorm(layer, activation, relevance)

            elif isinstance(layer, ACTIVATION_LAYERS):
                relevance = self._propagate_activation(layer, activation, relevance)

            elif isinstance(layer, nn.MaxPool2d):
                relevance = self._propagate_maxpool2d(layer, activation, relevance)

            elif isinstance(layer, (nn.AvgPool2d, nn.AdaptiveAvgPool2d)):
                relevance = self._propagate_avgpool2d(layer, activation, relevance)

            elif isinstance(layer, nn.Flatten):
                # Reshape relevance back to pre-flatten shape
                if activation.dim() > 2:
                    relevance = relevance.reshape(activation.shape)

            elif isinstance(layer, nn.Unflatten):
                # Unflatten in forward expands dimensions: (batch, features) -> (batch, *dims)
                # In backward, reshape relevance to match the flattened input activation
                relevance = relevance.view(activation.shape)

            elif isinstance(layer, PASSTHROUGH_LAYERS):
                # Dropout and other passthrough layers
                pass

            layer_relevances[f"layer_{idx}_{name}"] = relevance.detach().cpu().numpy().flatten()

        # Final relevance is the input attribution
        input_relevance = relevance.detach().cpu().numpy().flatten()
        layer_relevances["input"] = input_relevance

        if return_layer_relevances:
            return input_relevance, layer_relevances
        return input_relevance

    def explain(
        self,
        instance: np.ndarray,
        target_class: Optional[int] = None,
        return_convergence_delta: bool = False
    ) -> Explanation:
        """
        Generate LRP explanation for an instance.

        Args:
            instance: Numpy array of input features (1D for tabular,
                or multi-dimensional for images).
            target_class: For classification, which class to explain.
                If None, uses the predicted class.
            return_convergence_delta: If True, include the convergence delta
                (difference between sum of attributions and target output).
                Should be close to 0 for correct LRP (conservation property).

        Returns:
            Explanation object with feature attributions.

        Example:
            >>> explanation = explainer.explain(instance)
            >>> print(explanation.explanation_data["feature_attributions"])
        """
        instance = np.array(instance).astype(np.float32)
        original_shape = instance.shape
        instance_flat = instance.flatten()

        # Determine target class if not specified
        if target_class is None and self.class_names:
            # Get prediction using properly shaped input
            model = self._get_pytorch_model()
            model.eval()
            with torch.no_grad():
                x = self._prepare_input_tensor(instance)
                output = model(x)
                target_class = int(torch.argmax(output, dim=1).item())

        # Compute LRP attributions
        attributions_raw = self._compute_lrp(instance, target_class)

        # Build attributions dict
        if len(self.feature_names) == len(attributions_raw):
            attributions = {
                fname: float(attributions_raw[i])
                for i, fname in enumerate(self.feature_names)
            }
        else:
            # For images or mismatched feature names, use indices
            attributions = {
                f"feature_{i}": float(attributions_raw[i])
                for i in range(len(attributions_raw))
            }

        # Determine class name
        if self.class_names and target_class is not None:
            label_name = self.class_names[target_class]
        else:
            label_name = f"class_{target_class}" if target_class is not None else "output"

        explanation_data = {
            "feature_attributions": attributions,
            "attributions_raw": [float(x) for x in attributions_raw],
            "rule": self.rule,
            "epsilon": self.epsilon if self.rule in ["epsilon", "composite"] else None,
            "gamma": self.gamma if self.rule == "gamma" else None,
            "alpha": self.alpha if self.rule == "alpha_beta" else None,
            "beta": self.beta if self.rule == "alpha_beta" else None,
            "input_shape": list(original_shape)
        }

        # Compute convergence delta (conservation check)
        if return_convergence_delta:
            model = self._get_pytorch_model()
            model.eval()

            with torch.no_grad():
                # Use the helper method to get properly shaped input
                x = self._prepare_input_tensor(instance)
                output = model(x)

                if target_class is not None:
                    target_output = output[0, target_class].item()
                else:
                    target_output = output.sum().item()

            attribution_sum = sum(attributions.values())
            convergence_delta = abs(target_output - attribution_sum)

            explanation_data["target_output"] = float(target_output)
            explanation_data["attribution_sum"] = float(attribution_sum)
            explanation_data["convergence_delta"] = float(convergence_delta)

        return Explanation(
            explainer_name="LRP",
            target_class=label_name,
            explanation_data=explanation_data,
            feature_names=self.feature_names
        )

    def explain_batch(
        self,
        X: np.ndarray,
        target_class: Optional[int] = None
    ) -> List[Explanation]:
        """
        Generate explanations for multiple instances.

        Args:
            X: Array of instances. For tabular: (n_samples, n_features).
                For images: (n_samples, channels, height, width) or similar.
            target_class: Target class for all instances. If None,
                uses predicted class for each instance.

        Returns:
            List of Explanation objects.
        """
        X = np.array(X)

        # Handle single instance
        if X.ndim == 1:
            X = X.reshape(1, -1)

        # For multi-dimensional data (images), first dim is batch
        n_samples = X.shape[0]

        return [
            self.explain(X[i], target_class=target_class)
            for i in range(n_samples)
        ]

    def explain_with_layer_relevances(
        self,
        instance: np.ndarray,
        target_class: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Compute LRP with layer-wise relevance scores for detailed analysis.

        This method returns relevance scores at each layer, which is useful
        for understanding how relevance flows through the network and
        verifying the conservation property.

        Args:
            instance: Input instance.
            target_class: Target class for relevance computation.

        Returns:
            Dictionary containing:
                - input_relevances: Final attribution scores for input features
                - layer_relevances: Dict mapping layer names to relevance arrays
                - target_class: The target class used
                - rule: The rule used for computation
        """
        instance = np.array(instance).astype(np.float32)

        # Determine target class if not specified
        if target_class is None and self.class_names:
            # Get prediction using properly shaped input
            model = self._get_pytorch_model()
            model.eval()
            with torch.no_grad():
                x = self._prepare_input_tensor(instance)
                output = model(x)
                target_class = int(torch.argmax(output, dim=1).item())

        # Compute LRP with layer relevances
        input_relevances, layer_relevances = self._compute_lrp(
            instance, target_class, return_layer_relevances=True
        )

        return {
            "input_relevances": [float(x) for x in input_relevances],
            "layer_relevances": {
                name: [float(x) for x in rel] if isinstance(rel, np.ndarray) else float(rel)
                for name, rel in layer_relevances.items()
            },
            "target_class": target_class,
            "rule": self.rule,
            "feature_names": self.feature_names
        }

    def compare_rules(
        self,
        instance: np.ndarray,
        target_class: Optional[int] = None,
        rules: Optional[List[str]] = None
    ) -> Dict[str, Dict[str, Any]]:
        """
        Compare different LRP rules on the same instance.

        Useful for understanding how different rules affect attributions
        and for selecting the most appropriate rule for your use case.

        Args:
            instance: Input instance.
            target_class: Target class for comparison.
            rules: List of rules to compare. If None, compares all rules.

        Returns:
            Dictionary mapping rule names to their attribution results.
        """
        instance = np.array(instance).astype(np.float32)

        # Determine target class
        if target_class is None and self.class_names:
            # Get prediction using properly shaped input
            model = self._get_pytorch_model()
            model.eval()
            with torch.no_grad():
                x = self._prepare_input_tensor(instance)
                output = model(x)
                target_class = int(torch.argmax(output, dim=1).item())

        if rules is None:
            rules = ["epsilon", "gamma", "alpha_beta", "z_plus"]

        results = {}

        # Save original settings
        original_rule = self.rule

        for rule in rules:
            self.rule = rule

            try:
                attributions = self._compute_lrp(instance, target_class)

                # Find top feature
                top_idx = int(np.argmax(np.abs(attributions)))
                if top_idx < len(self.feature_names):
                    top_feature = self.feature_names[top_idx]
                else:
                    top_feature = f"feature_{top_idx}"

                results[rule] = {
                    "attributions": [float(x) for x in attributions],
                    "top_feature": top_feature,
                    "top_attribution": float(attributions[top_idx]),
                    "attribution_sum": float(np.sum(attributions)),
                    "attribution_range": (float(np.min(attributions)), float(np.max(attributions)))
                }
            except Exception as e:
                results[rule] = {"error": str(e)}

        # Restore original rule
        self.rule = original_rule

        return results
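The conservation property that convergence_delta reports can be sanity-checked outside the package. Below is a minimal, self-contained sketch of the ε-rule backward pass on a small bias-free two-layer ReLU network, written directly against PyTorch; the script name, model, layer sizes, and seed are illustrative assumptions rather than anything shipped in this wheel, but the redistribution step mirrors _lrp_linear_epsilon above, and for a bias-free network the printed delta should be near zero.

# conservation_check.py -- illustrative sketch only; not part of the explainiverse wheel.
import torch
import torch.nn as nn

torch.manual_seed(0)
torch.set_grad_enabled(False)  # pure inference; no autograd needed

# Stand-in model (assumed): two bias-free Linear layers with a ReLU in between.
model = nn.Sequential(
    nn.Linear(4, 8, bias=False),
    nn.ReLU(),
    nn.Linear(8, 3, bias=False),
)
model.eval()

x = torch.randn(1, 4)

# Forward pass, keeping the input seen by each Linear layer.
linear_inputs = []
current = x
for layer in model:
    if isinstance(layer, nn.Linear):
        linear_inputs.append(current)
    current = layer(current)
output = current

# Initialize relevance with the predicted class's logit only.
target = int(output.argmax(dim=1))
relevance = torch.zeros_like(output)
relevance[0, target] = output[0, target]
expected = float(relevance.sum())

# ε-rule backward pass through the Linear layers (ReLU passes relevance through).
epsilon = 1e-6
linears = [m for m in model if isinstance(m, nn.Linear)]
for layer, a in zip(reversed(linears), reversed(linear_inputs)):
    z = a @ layer.weight.t()             # pre-activations of this layer
    z = z + epsilon * torch.sign(z)      # stabilized denominator
    s = relevance / z
    relevance = a * (s @ layer.weight)   # redistribute relevance onto the inputs

delta = abs(expected - float(relevance.sum()))
print(f"logit={expected:.6f}  attribution_sum={float(relevance.sum()):.6f}  delta={delta:.2e}")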