explainiverse 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1206 @@
1
+ # src/explainiverse/explainers/gradient/lrp.py
2
+ """
3
+ Layer-wise Relevance Propagation (LRP) - Decomposition-based Attribution.
4
+
5
+ LRP decomposes network predictions back to input features using a conservation
6
+ principle. Unlike gradient-based methods, LRP propagates relevance scores
7
+ layer-by-layer through the network using specific propagation rules.
8
+
9
+ Key Properties:
10
+ - Conservation: the total relevance at each layer equals the explained output score
11
+ - Layer-wise decomposition: Relevance flows backward through layers
12
+ - Multiple rules: Different rules for different layer types and use cases
13
+
14
+ Supported Layer Types:
15
+ - Linear (fully connected)
16
+ - Conv2d (convolutional)
17
+ - BatchNorm1d, BatchNorm2d
18
+ - ReLU, LeakyReLU, ELU, GELU, Tanh, Sigmoid
19
+ - MaxPool2d, AvgPool2d
20
+ - Flatten, Dropout (passthrough)
21
+
22
+ Propagation Rules:
23
+ - LRP-0: Basic rule (no stabilization) - theoretical baseline, recovered from LRP-ε as ε → 0
24
+ - LRP-ε (epsilon): Adds small constant for numerical stability (recommended)
25
+ - LRP-γ (gamma): Enhances positive contributions - good for image classification
26
+ - LRP-αβ (alpha-beta): Separates positive/negative contributions - fine control
27
+ - LRP-z⁺ (z-plus): Only considers positive weights - often used for input layers
28
+ - Composite: Different rules for different layers
29
+
30
+ Mathematical Formulation:
31
+ For a layer with input activations a and pre-activations z = Wa + b:
32
+
33
+ LRP-0: R_j = Σ_k (a_j * w_jk / z_k) * R_k
34
+ LRP-ε: R_j = Σ_k (a_j * w_jk / (z_k + ε*sign(z_k))) * R_k
35
+ LRP-γ: R_j = Σ_k (a_j * (w_jk + γ*w_jk⁺) / Σ_i a_i * (w_ik + γ*w_ik⁺)) * R_k
36
+ LRP-αβ: R_j = Σ_k (α * (a_j * w_jk)⁺ / z_k⁺ - β * (a_j * w_jk)⁻ / z_k⁻) * R_k
37
+ LRP-z⁺: R_j = Σ_k (a_j * w_jk⁺ / Σ_i a_i * w_ik⁺) * R_k
38
+
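+ A worked example (LRP-0, illustrative): for inputs a = (1, 3) and weights
+ w = (2, -1), z = 1*2 + 3*(-1) = -1. Taking R = z = -1, the rule assigns
+ R_1 = (1*2 / -1) * (-1) = 2 and R_2 = (3*(-1) / -1) * (-1) = -3, which
+ sum back to -1 = R, illustrating the conservation property.
+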
39
+ Reference:
40
+ Bach, S., Binder, A., Montavon, G., Klauschen, F., Müller, K. R., & Samek, W. (2015).
41
+ On Pixel-wise Explanations for Non-Linear Classifier Decisions by Layer-wise
42
+ Relevance Propagation. PLOS ONE.
43
+ https://doi.org/10.1371/journal.pone.0130140
44
+
45
+ Montavon, G., Binder, A., Lapuschkin, S., Samek, W., & Müller, K. R. (2019).
46
+ Layer-wise Relevance Propagation: An Overview. Explainable AI: Interpreting,
47
+ Explaining and Visualizing Deep Learning. Springer.
48
+
49
+ Example:
50
+ from explainiverse.explainers.gradient import LRPExplainer
51
+ from explainiverse.adapters import PyTorchAdapter
52
+
53
+ adapter = PyTorchAdapter(model, task="classification")
54
+
55
+ explainer = LRPExplainer(
56
+ model=adapter,
57
+ feature_names=feature_names,
58
+ rule="epsilon",
59
+ epsilon=1e-6
60
+ )
61
+
62
+ explanation = explainer.explain(instance)
63
+ """
64
+
65
+ import numpy as np
66
+ from typing import List, Optional, Dict, Any, Tuple, Union
67
+ from collections import OrderedDict
68
+
69
+ from explainiverse.core.explainer import BaseExplainer
70
+ from explainiverse.core.explanation import Explanation
71
+
72
+
73
+ # Check if PyTorch is available
74
+ try:
75
+ import torch
76
+ import torch.nn as nn
77
+ import torch.nn.functional as F
78
+ TORCH_AVAILABLE = True
79
+ except ImportError:
80
+ TORCH_AVAILABLE = False
81
+ torch = None
82
+ nn = None
83
+ F = None
84
+
85
+
86
+ # Valid LRP rules
87
+ VALID_RULES = ["epsilon", "gamma", "alpha_beta", "z_plus", "composite"]
88
+
89
+ # Layer types that require special LRP handling
90
+ WEIGHTED_LAYERS = (nn.Linear, nn.Conv2d) if TORCH_AVAILABLE else ()
91
+ NORMALIZATION_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d) if TORCH_AVAILABLE else ()
92
+ ACTIVATION_LAYERS = (nn.ReLU, nn.LeakyReLU, nn.ELU, nn.Tanh, nn.Sigmoid, nn.GELU) if TORCH_AVAILABLE else ()
93
+ POOLING_LAYERS = (nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d) if TORCH_AVAILABLE else ()
94
+ PASSTHROUGH_LAYERS = (nn.Dropout, nn.Dropout2d, nn.Flatten) if TORCH_AVAILABLE else ()
95
+
96
+
97
+ class LRPExplainer(BaseExplainer):
98
+ """
99
+ Layer-wise Relevance Propagation (LRP) explainer for neural networks.
100
+
101
+ LRP decomposes the network output into relevance scores for each input
102
+ feature by propagating relevance backward through the network layers.
103
+ The key property is conservation: the sum of relevances at each layer
104
+ equals the relevance at the layer above.
105
+
106
+ Supports:
107
+ - Fully connected networks (Linear + activations)
108
+ - Convolutional networks (Conv2d + BatchNorm + pooling)
109
+ - Mixed architectures
110
+
111
+ Attributes:
112
+ model: Model adapter (must be PyTorchAdapter)
113
+ feature_names: List of feature names
114
+ class_names: List of class names (for classification)
115
+ rule: Propagation rule ("epsilon", "gamma", "alpha_beta", "z_plus", "composite")
116
+ epsilon: Stabilization constant for epsilon rule
117
+ gamma: Enhancement factor for gamma rule
118
+ alpha: Positive contribution weight for alpha-beta rule
119
+ beta: Negative contribution weight for alpha-beta rule
120
+
121
+ Example:
122
+ >>> explainer = LRPExplainer(adapter, feature_names, rule="epsilon")
123
+ >>> explanation = explainer.explain(instance)
124
+ >>> print(explanation.explanation_data["feature_attributions"])
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ model,
130
+ feature_names: List[str],
131
+ class_names: Optional[List[str]] = None,
132
+ rule: str = "epsilon",
133
+ epsilon: float = 1e-6,
134
+ gamma: float = 0.25,
135
+ alpha: float = 2.0,
136
+ beta: float = 1.0
137
+ ):
138
+ """
139
+ Initialize the LRP explainer.
140
+
141
+ Args:
142
+ model: A PyTorchAdapter wrapping the model to explain.
143
+ feature_names: List of input feature names.
144
+ class_names: List of class names (for classification tasks).
145
+ rule: Propagation rule to use:
146
+ - "epsilon": LRP-ε with stabilization (default, recommended)
147
+ - "gamma": LRP-γ enhancing positive contributions
148
+ - "alpha_beta": LRP-αβ separating pos/neg contributions
149
+ - "z_plus": LRP-z⁺ using only positive weights
150
+ - "composite": Different rules for different layers
151
+ epsilon: Small constant for numerical stability in epsilon rule.
152
+ Default: 1e-6
153
+ gamma: Factor to enhance positive contributions in gamma rule.
154
+ Default: 0.25
155
+ alpha: Weight for positive contributions in alpha-beta rule.
156
+ Must satisfy alpha - beta = 1. Default: 2.0
157
+ beta: Weight for negative contributions in alpha-beta rule.
158
+ Must satisfy alpha - beta = 1. Default: 1.0
159
+
160
+ Raises:
161
+ TypeError: If model is not a PyTorchAdapter.
162
+ ValueError: If rule is invalid or alpha-beta constraint violated.
163
+ """
164
+ if not TORCH_AVAILABLE:
165
+ raise ImportError(
166
+ "PyTorch is required for LRP. Install with: pip install torch"
167
+ )
168
+
169
+ super().__init__(model)
170
+
171
+ # Validate model is PyTorchAdapter
172
+ if not hasattr(model, 'model') or not isinstance(model.model, nn.Module):
173
+ raise TypeError(
174
+ "LRP requires a PyTorchAdapter wrapping a PyTorch model. "
175
+ "Use: PyTorchAdapter(your_model, task='classification')"
176
+ )
177
+
178
+ # Validate rule
179
+ if rule not in VALID_RULES:
180
+ raise ValueError(
181
+ f"Invalid rule: '{rule}'. Must be one of: {VALID_RULES}"
182
+ )
183
+
184
+ # Validate alpha-beta constraint
185
+ if rule == "alpha_beta":
186
+ if not np.isclose(alpha - beta, 1.0):
187
+ raise ValueError(
188
+ f"For alpha-beta rule, alpha - beta must equal 1. "
189
+ f"Got alpha={alpha}, beta={beta}, difference={alpha - beta}"
190
+ )
191
+
192
+ self.feature_names = list(feature_names)
193
+ self.class_names = list(class_names) if class_names else None
194
+ self.rule = rule
195
+ self.epsilon = epsilon
196
+ self.gamma = gamma
197
+ self.alpha = alpha
198
+ self.beta = beta
199
+
200
+ # For composite rules
201
+ self._layer_rules: Optional[Dict[int, str]] = None
202
+
203
+ # Cache for layer information
204
+ self._layers_info: Optional[List[Dict[str, Any]]] = None
205
+
206
+ def _get_pytorch_model(self) -> nn.Module:
207
+ """Get the underlying PyTorch model."""
208
+ return self.model.model
209
+
210
+ def _is_cnn_model(self) -> bool:
211
+ """Check if the model's first weighted layer is Conv2d."""
212
+ model = self._get_pytorch_model()
213
+ for module in model.modules():
214
+ if isinstance(module, nn.Conv2d):
215
+ return True
216
+ if isinstance(module, nn.Linear):
217
+ return False
218
+ return False
219
+
220
+ def _prepare_input_tensor(self, instance: np.ndarray) -> torch.Tensor:
221
+ """
222
+ Prepare input tensor with correct shape for the model.
223
+
224
+ For CNN models, preserves the spatial dimensions.
225
+ For MLP models, flattens to 2D.
226
+
227
+ Args:
228
+ instance: Input array (1D for tabular, 3D for images)
229
+
230
+ Returns:
231
+ Tensor with batch dimension added and correct shape for model
232
+ """
233
+ instance = np.array(instance).astype(np.float32)
234
+ original_shape = instance.shape
235
+
236
+ model = self._get_pytorch_model()
237
+
238
+ # Find first weighted layer to determine input type
239
+ first_layer = None
240
+ for module in model.modules():
241
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
242
+ first_layer = module
243
+ break
244
+
245
+ if isinstance(first_layer, nn.Conv2d):
246
+ # CNN model - need 4D input (batch, channels, height, width)
247
+ in_channels = first_layer.in_channels
248
+
249
+ if len(original_shape) >= 3:
250
+ # Already (C, H, W) format - just add batch dimension
251
+ x = torch.tensor(instance).unsqueeze(0)
252
+ elif len(original_shape) == 2:
253
+ # (H, W) - assume single channel, add channel and batch dimensions
254
+ x = torch.tensor(instance).unsqueeze(0).unsqueeze(0)
255
+ else:
256
+ # Flattened - try to infer spatial dimensions
257
+ n_features = instance.size
258
+ if n_features % in_channels == 0:
259
+ spatial_size = int(np.sqrt(n_features // in_channels))
260
+ if spatial_size * spatial_size * in_channels == n_features:
261
+ x = torch.tensor(instance.flatten()).reshape(
262
+ 1, in_channels, spatial_size, spatial_size
263
+ )
264
+ else:
265
+ raise ValueError(
266
+ f"Cannot infer spatial dimensions for {n_features} features "
267
+ f"with {in_channels} channels"
268
+ )
269
+ else:
270
+ raise ValueError(
271
+ f"Number of features ({n_features}) not divisible by "
272
+ f"input channels ({in_channels})"
273
+ )
274
+ else:
275
+ # MLP model - need 2D input (batch, features)
276
+ x = torch.tensor(instance.flatten()).reshape(1, -1)
277
+
278
+ return x.float()
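+
+ # Shape-handling sketch (illustrative, example values assumed): a flattened
+ # 784-element input with a Conv2d(in_channels=1) front-end is reshaped to
+ # (1, 1, 28, 28), while a tabular vector of length d becomes (1, d) when
+ # the first weighted layer is Linear.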
279
+
280
+ def _get_rule_for_layer(self, layer_idx: int, layer_type: str) -> str:
281
+ """
282
+ Get the propagation rule for a specific layer.
283
+
284
+ Args:
285
+ layer_idx: Index of the layer
286
+ layer_type: Type of the layer (e.g., "Linear", "Conv2d")
287
+
288
+ Returns:
289
+ Rule name to use for this layer
290
+ """
291
+ if self.rule != "composite":
292
+ return self.rule
293
+
294
+ # Composite rule: check layer-specific rules
295
+ if self._layer_rules and layer_idx in self._layer_rules:
296
+ return self._layer_rules[layer_idx]
297
+
298
+ # Default fallback for composite
299
+ return "epsilon"
300
+
301
+ def set_composite_rule(self, layer_rules: Dict[int, str]) -> "LRPExplainer":
302
+ """
303
+ Set layer-specific rules for composite LRP.
304
+
305
+ This allows using different propagation rules for different layers,
306
+ which is often beneficial. A common practice is:
307
+ - z_plus for input/early layers (focuses on what's present)
308
+ - epsilon for middle layers (balanced attribution)
309
+ - epsilon for the final layers
310
+
311
+ Args:
312
+ layer_rules: Dictionary mapping layer indices to rule names.
313
+ Layers not in this dict use "epsilon" by default.
314
+
315
+ Returns:
316
+ Self for method chaining.
317
+
318
+ Example:
319
+ >>> explainer.set_composite_rule({
320
+ ... 0: "z_plus", # First layer
321
+ ... 2: "epsilon", # Middle layer
322
+ ... 4: "epsilon" # Final layer
323
+ ... })
324
+ """
325
+ # Validate rules
326
+ for idx, rule in layer_rules.items():
327
+ if rule not in VALID_RULES and rule != "composite":
328
+ raise ValueError(f"Invalid rule '{rule}' for layer {idx}")
329
+
330
+ self._layer_rules = layer_rules
331
+ return self
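+
+ # Composite usage sketch (illustrative): layer indices refer to the order
+ # in which hooked layers are traversed in the forward pass.
+ #
+ #     explainer = LRPExplainer(adapter, feature_names, rule="composite")
+ #     explainer.set_composite_rule({0: "z_plus", 2: "epsilon"})
+ #     explanation = explainer.explain(instance)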
332
+
333
+ # =========================================================================
334
+ # Linear Layer LRP Rules
335
+ # =========================================================================
336
+
337
+ def _lrp_linear_epsilon(
338
+ self,
339
+ layer: nn.Linear,
340
+ activation: torch.Tensor,
341
+ relevance: torch.Tensor,
342
+ epsilon: float
343
+ ) -> torch.Tensor:
344
+ """
345
+ LRP-epsilon rule for linear layers.
346
+
347
+ R_j = Σ_k (a_j * w_jk / (z_k + ε*sign(z_k))) * R_k
348
+ """
349
+ # Forward pass to get z
350
+ z = torch.mm(activation, layer.weight.t())
351
+ if layer.bias is not None:
352
+ z = z + layer.bias
353
+
354
+ # Stabilize: z + epsilon * sign(z)
355
+ z_stabilized = z + epsilon * torch.sign(z)
356
+ z_stabilized = torch.where(
357
+ torch.abs(z_stabilized) < epsilon,
358
+ torch.full_like(z_stabilized, epsilon),
359
+ z_stabilized
360
+ )
361
+
362
+ # Compute relevance contribution: (R / z_stabilized) @ W
363
+ s = relevance / z_stabilized
364
+ c = torch.mm(s, layer.weight)
365
+
366
+ return activation * c
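+
+ # The three steps above implement the epsilon rule from the module
+ # docstring: R_j = a_j * Σ_k w_kj * R_k / (z_k + ε*sign(z_k)). Apart from
+ # relevance absorbed by the bias and the ε stabilizer, the input
+ # relevances sum back to the incoming relevance (conservation).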
367
+
368
+ def _lrp_linear_gamma(
369
+ self,
370
+ layer: nn.Linear,
371
+ activation: torch.Tensor,
372
+ relevance: torch.Tensor,
373
+ gamma: float
374
+ ) -> torch.Tensor:
375
+ """
376
+ LRP-gamma rule for linear layers.
377
+ Enhances positive contributions for sharper attributions.
378
+ """
379
+ w_plus = torch.clamp(layer.weight, min=0)
380
+ w_modified = layer.weight + gamma * w_plus
381
+
382
+ z = torch.mm(activation, w_modified.t())
383
+ if layer.bias is not None:
384
+ b_plus = torch.clamp(layer.bias, min=0)
385
+ z = z + layer.bias + gamma * b_plus
386
+
387
+ z_stabilized = z + self.epsilon * torch.sign(z)
388
+ z_stabilized = torch.where(
389
+ torch.abs(z_stabilized) < self.epsilon,
390
+ torch.full_like(z_stabilized, self.epsilon),
391
+ z_stabilized
392
+ )
393
+
394
+ s = relevance / z_stabilized
395
+ c = torch.mm(s, w_modified)
396
+
397
+ return activation * c
398
+
399
+ def _lrp_linear_alpha_beta(
400
+ self,
401
+ layer: nn.Linear,
402
+ activation: torch.Tensor,
403
+ relevance: torch.Tensor,
404
+ alpha: float,
405
+ beta: float
406
+ ) -> torch.Tensor:
407
+ """
408
+ LRP-alpha-beta rule for linear layers.
409
+ Separates positive and negative contributions. Uses only the positive
+ part of the activations, as in the standard formulation for post-ReLU layers.
410
+ """
411
+ w_plus = torch.clamp(layer.weight, min=0)
412
+ w_minus = torch.clamp(layer.weight, max=0)
413
+ a_plus = torch.clamp(activation, min=0)
414
+
415
+ z_plus = torch.mm(a_plus, w_plus.t())
416
+ if layer.bias is not None:
417
+ z_plus = z_plus + torch.clamp(layer.bias, min=0)
418
+
419
+ z_minus = torch.mm(a_plus, w_minus.t())
420
+ if layer.bias is not None:
421
+ z_minus = z_minus + torch.clamp(layer.bias, max=0)
422
+
423
+ z_plus_stable = z_plus + self.epsilon
424
+ z_minus_stable = z_minus - self.epsilon
425
+ z_minus_stable = torch.where(
426
+ torch.abs(z_minus_stable) < self.epsilon,
427
+ torch.full_like(z_minus_stable, -self.epsilon),
428
+ z_minus_stable
429
+ )
430
+
431
+ s_plus = relevance / z_plus_stable
432
+ s_minus = relevance / z_minus_stable
433
+
434
+ c_plus = torch.mm(s_plus, w_plus)
435
+ c_minus = torch.mm(s_minus, w_minus)
436
+
437
+ return alpha * a_plus * c_plus - beta * a_plus * c_minus
438
+
439
+ def _lrp_linear_z_plus(
440
+ self,
441
+ layer: nn.Linear,
442
+ activation: torch.Tensor,
443
+ relevance: torch.Tensor
444
+ ) -> torch.Tensor:
445
+ """
446
+ LRP-z+ rule for linear layers.
447
+ Only considers positive weights.
448
+ """
449
+ w_plus = torch.clamp(layer.weight, min=0)
450
+ a_plus = torch.clamp(activation, min=0)
451
+
452
+ z_plus = torch.mm(a_plus, w_plus.t())
453
+ if layer.bias is not None:
454
+ z_plus = z_plus + torch.clamp(layer.bias, min=0)
455
+
456
+ z_plus_stable = z_plus + self.epsilon
457
+
458
+ s = relevance / z_plus_stable
459
+ c = torch.mm(s, w_plus)
460
+
461
+ return a_plus * c
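+
+ # Note: z+ coincides with the alpha-beta rule at alpha=1, beta=0 for
+ # non-negative activations, which is why it is often preferred for
+ # input-adjacent layers where only present (positive) evidence should
+ # receive relevance.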
462
+
463
+ def _propagate_linear(
464
+ self,
465
+ layer: nn.Linear,
466
+ activation: torch.Tensor,
467
+ relevance: torch.Tensor,
468
+ rule: str
469
+ ) -> torch.Tensor:
470
+ """Propagate relevance through a linear layer."""
471
+ if rule == "epsilon":
472
+ return self._lrp_linear_epsilon(layer, activation, relevance, self.epsilon)
473
+ elif rule == "gamma":
474
+ return self._lrp_linear_gamma(layer, activation, relevance, self.gamma)
475
+ elif rule == "alpha_beta":
476
+ return self._lrp_linear_alpha_beta(layer, activation, relevance, self.alpha, self.beta)
477
+ elif rule == "z_plus":
478
+ return self._lrp_linear_z_plus(layer, activation, relevance)
479
+ else:
480
+ return self._lrp_linear_epsilon(layer, activation, relevance, self.epsilon)
481
+
482
+ # =========================================================================
483
+ # Conv2d Layer LRP Rules
484
+ # =========================================================================
485
+
486
+ def _lrp_conv2d_epsilon(
487
+ self,
488
+ layer: nn.Conv2d,
489
+ activation: torch.Tensor,
490
+ relevance: torch.Tensor,
491
+ epsilon: float
492
+ ) -> torch.Tensor:
493
+ """
494
+ LRP-epsilon rule for Conv2d layers.
495
+ Uses convolution transpose for backward relevance propagation.
496
+ """
497
+ # Forward pass to get z
498
+ z = F.conv2d(
499
+ activation,
500
+ layer.weight,
501
+ bias=layer.bias,
502
+ stride=layer.stride,
503
+ padding=layer.padding,
504
+ dilation=layer.dilation,
505
+ groups=layer.groups
506
+ )
507
+
508
+ # Stabilize
509
+ z_stabilized = z + epsilon * torch.sign(z)
510
+ z_stabilized = torch.where(
511
+ torch.abs(z_stabilized) < epsilon,
512
+ torch.full_like(z_stabilized, epsilon),
513
+ z_stabilized
514
+ )
515
+
516
+ # Compute s = R / z
517
+ s = relevance / z_stabilized
518
+
519
+ # Backward pass using conv_transpose2d
520
+ c = F.conv_transpose2d(
521
+ s,
522
+ layer.weight,
523
+ bias=None,
524
+ stride=layer.stride,
525
+ padding=layer.padding,
526
+ output_padding=0,
527
+ groups=layer.groups,
528
+ dilation=layer.dilation
529
+ )
530
+
531
+ # Handle output size mismatch
532
+ if c.shape != activation.shape:
533
+ # Pad or crop to match activation shape
534
+ diff_h = activation.shape[2] - c.shape[2]
535
+ diff_w = activation.shape[3] - c.shape[3]
536
+ if diff_h > 0 or diff_w > 0:
537
+ c = F.pad(c, [0, max(0, diff_w), 0, max(0, diff_h)])
538
+ if diff_h < 0 or diff_w < 0:
539
+ c = c[:, :, :activation.shape[2], :activation.shape[3]]
540
+
541
+ return activation * c
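+
+ # Note: conv_transpose2d with the same weights applies the adjoint of the
+ # forward convolution, so the step above is the convolutional analogue of
+ # the matrix product c = s @ W used in the linear-layer rules.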
542
+
543
+ def _lrp_conv2d_gamma(
544
+ self,
545
+ layer: nn.Conv2d,
546
+ activation: torch.Tensor,
547
+ relevance: torch.Tensor,
548
+ gamma: float
549
+ ) -> torch.Tensor:
550
+ """LRP-gamma rule for Conv2d layers."""
551
+ w_plus = torch.clamp(layer.weight, min=0)
552
+ w_modified = layer.weight + gamma * w_plus
553
+
554
+ z = F.conv2d(
555
+ activation,
556
+ w_modified,
557
+ bias=layer.bias,
558
+ stride=layer.stride,
559
+ padding=layer.padding,
560
+ dilation=layer.dilation,
561
+ groups=layer.groups
562
+ )
563
+
564
+ if layer.bias is not None:
565
+ b_plus = torch.clamp(layer.bias, min=0)
566
+ # Bias is already added in conv2d, add gamma * b_plus
567
+ z = z + gamma * b_plus.view(1, -1, 1, 1)
568
+
569
+ z_stabilized = z + self.epsilon * torch.sign(z)
570
+ z_stabilized = torch.where(
571
+ torch.abs(z_stabilized) < self.epsilon,
572
+ torch.full_like(z_stabilized, self.epsilon),
573
+ z_stabilized
574
+ )
575
+
576
+ s = relevance / z_stabilized
577
+
578
+ c = F.conv_transpose2d(
579
+ s,
580
+ w_modified,
581
+ bias=None,
582
+ stride=layer.stride,
583
+ padding=layer.padding,
584
+ output_padding=0,
585
+ groups=layer.groups,
586
+ dilation=layer.dilation
587
+ )
588
+
589
+ if c.shape != activation.shape:
590
+ diff_h = activation.shape[2] - c.shape[2]
591
+ diff_w = activation.shape[3] - c.shape[3]
592
+ if diff_h > 0 or diff_w > 0:
593
+ c = F.pad(c, [0, max(0, diff_w), 0, max(0, diff_h)])
594
+ if diff_h < 0 or diff_w < 0:
595
+ c = c[:, :, :activation.shape[2], :activation.shape[3]]
596
+
597
+ return activation * c
598
+
599
+ def _lrp_conv2d_z_plus(
600
+ self,
601
+ layer: nn.Conv2d,
602
+ activation: torch.Tensor,
603
+ relevance: torch.Tensor
604
+ ) -> torch.Tensor:
605
+ """LRP-z+ rule for Conv2d layers."""
606
+ w_plus = torch.clamp(layer.weight, min=0)
607
+ a_plus = torch.clamp(activation, min=0)
608
+
609
+ z_plus = F.conv2d(
610
+ a_plus,
611
+ w_plus,
612
+ bias=None, # Ignore bias for z+
613
+ stride=layer.stride,
614
+ padding=layer.padding,
615
+ dilation=layer.dilation,
616
+ groups=layer.groups
617
+ )
618
+
619
+ if layer.bias is not None:
620
+ z_plus = z_plus + torch.clamp(layer.bias, min=0).view(1, -1, 1, 1)
621
+
622
+ z_plus_stable = z_plus + self.epsilon
623
+
624
+ s = relevance / z_plus_stable
625
+
626
+ c = F.conv_transpose2d(
627
+ s,
628
+ w_plus,
629
+ bias=None,
630
+ stride=layer.stride,
631
+ padding=layer.padding,
632
+ output_padding=0,
633
+ groups=layer.groups,
634
+ dilation=layer.dilation
635
+ )
636
+
637
+ if c.shape != a_plus.shape:
638
+ diff_h = a_plus.shape[2] - c.shape[2]
639
+ diff_w = a_plus.shape[3] - c.shape[3]
640
+ if diff_h > 0 or diff_w > 0:
641
+ c = F.pad(c, [0, max(0, diff_w), 0, max(0, diff_h)])
642
+ if diff_h < 0 or diff_w < 0:
643
+ c = c[:, :, :a_plus.shape[2], :a_plus.shape[3]]
644
+
645
+ return a_plus * c
646
+
647
+ def _propagate_conv2d(
648
+ self,
649
+ layer: nn.Conv2d,
650
+ activation: torch.Tensor,
651
+ relevance: torch.Tensor,
652
+ rule: str
653
+ ) -> torch.Tensor:
654
+ """Propagate relevance through a Conv2d layer."""
655
+ if rule == "epsilon":
656
+ return self._lrp_conv2d_epsilon(layer, activation, relevance, self.epsilon)
657
+ elif rule == "gamma":
658
+ return self._lrp_conv2d_gamma(layer, activation, relevance, self.gamma)
659
+ elif rule == "z_plus":
660
+ return self._lrp_conv2d_z_plus(layer, activation, relevance)
661
+ elif rule == "alpha_beta":
662
+ # Alpha-beta for conv is complex, fall back to epsilon
663
+ return self._lrp_conv2d_epsilon(layer, activation, relevance, self.epsilon)
664
+ else:
665
+ return self._lrp_conv2d_epsilon(layer, activation, relevance, self.epsilon)
666
+
667
+ # =========================================================================
668
+ # BatchNorm Layer LRP Rules
669
+ # =========================================================================
670
+
671
+ def _propagate_batchnorm(
672
+ self,
673
+ layer: Union[nn.BatchNorm1d, nn.BatchNorm2d],
674
+ activation: torch.Tensor,
675
+ relevance: torch.Tensor
676
+ ) -> torch.Tensor:
677
+ """
678
+ Propagate relevance through BatchNorm layer.
679
+
680
+ BatchNorm is an affine transformation: y = gamma * (x - mean) / std + beta
681
+ We treat it as a linear scaling and propagate relevance proportionally.
682
+ """
683
+ # Get BatchNorm parameters
684
+ if layer.running_mean is None or layer.running_var is None:
685
+ # If no running stats, pass through
686
+ return relevance
687
+
688
+ mean = layer.running_mean
689
+ var = layer.running_var
690
+ eps = layer.eps
691
+
692
+ # Compute the effective scale factor
693
+ std = torch.sqrt(var + eps)
694
+
695
+ if layer.weight is not None:
696
+ scale = layer.weight / std
697
+ else:
698
+ scale = 1.0 / std
699
+
700
+ # Reshape scale for broadcasting
701
+ if isinstance(layer, nn.BatchNorm2d):
702
+ scale = scale.view(1, -1, 1, 1)
703
+ else:
704
+ scale = scale.view(1, -1)
705
+
706
+ # Relevance propagation: R_input = R_output (scaled back)
707
+ # Since BN is essentially a rescaling, we redistribute proportionally
708
+ return relevance / (scale + self.epsilon * torch.sign(scale))
709
+
710
+ # =========================================================================
711
+ # Activation Layer LRP Rules
712
+ # =========================================================================
713
+
714
+ def _propagate_activation(
715
+ self,
716
+ layer: nn.Module,
717
+ activation: torch.Tensor,
718
+ relevance: torch.Tensor
719
+ ) -> torch.Tensor:
720
+ """
721
+ Propagate relevance through activation layers (ReLU, etc.).
722
+
723
+ For element-wise activations, relevance passes through unchanged
724
+ to locations where the activation was positive.
725
+ """
726
+ if isinstance(layer, nn.ReLU):
727
+ # Element-wise ReLU is transparent in LRP: the preceding weighted layer
+ # already assigns zero relevance where the activation is zero, so
+ # relevance passes through unchanged
+ return relevance
731
+ elif isinstance(layer, (nn.LeakyReLU, nn.ELU)):
732
+ # For leaky activations, relevance passes through
733
+ return relevance
734
+ elif isinstance(layer, (nn.Tanh, nn.Sigmoid)):
735
+ # For bounded activations, relevance passes through
736
+ return relevance
737
+ else:
738
+ # Default: pass through
739
+ return relevance
740
+
741
+ # =========================================================================
742
+ # Pooling Layer LRP Rules
743
+ # =========================================================================
744
+
745
+ def _propagate_maxpool2d(
746
+ self,
747
+ layer: nn.MaxPool2d,
748
+ activation: torch.Tensor,
749
+ relevance: torch.Tensor
750
+ ) -> torch.Tensor:
751
+ """
752
+ Propagate relevance through MaxPool2d.
753
+
754
+ Relevance is distributed to the max locations (winner-take-all).
755
+ """
756
+ # Forward pass to get indices
757
+ _, indices = F.max_pool2d(
758
+ activation,
759
+ kernel_size=layer.kernel_size,
760
+ stride=layer.stride,
761
+ padding=layer.padding,
762
+ dilation=layer.dilation,
763
+ return_indices=True,
764
+ ceil_mode=layer.ceil_mode
765
+ )
766
+
767
+ # Unpool: place relevance at max locations
768
+ unpooled = F.max_unpool2d(
769
+ relevance,
770
+ indices,
771
+ kernel_size=layer.kernel_size,
772
+ stride=layer.stride,
773
+ padding=layer.padding,
774
+ output_size=activation.shape
775
+ )
776
+
777
+ return unpooled
778
+
779
+ def _propagate_avgpool2d(
780
+ self,
781
+ layer: Union[nn.AvgPool2d, nn.AdaptiveAvgPool2d],
782
+ activation: torch.Tensor,
783
+ relevance: torch.Tensor
784
+ ) -> torch.Tensor:
785
+ """
786
+ Propagate relevance through AvgPool2d.
787
+
788
+ Relevance is distributed uniformly across pooling regions.
789
+ """
790
+ if isinstance(layer, nn.AdaptiveAvgPool2d):
+ # For adaptive pooling, upsample relevance to the input size and
+ # divide by the average region area so total relevance is conserved
+ area = (activation.shape[2] * activation.shape[3]) / (relevance.shape[2] * relevance.shape[3])
+ upsampled = F.interpolate(
+ relevance,
+ size=activation.shape[2:],
+ mode='nearest'
+ )
+ return upsampled / area
+ else:
+ # For regular avg pooling, replicate relevance across each pooling
+ # region and divide by the pool area (uniform redistribution)
+ kernel_size = layer.kernel_size if isinstance(layer.kernel_size, tuple) else (layer.kernel_size, layer.kernel_size)
+ area = kernel_size[0] * kernel_size[1]
+
+ upsampled = F.interpolate(
+ relevance,
+ size=activation.shape[2:],
+ mode='nearest'
+ )
+
+ return upsampled / area
809
+
810
+ # =========================================================================
811
+ # Main LRP Computation
812
+ # =========================================================================
813
+
814
+ def _compute_lrp(
815
+ self,
816
+ instance: np.ndarray,
817
+ target_class: Optional[int] = None,
818
+ return_layer_relevances: bool = False
819
+ ) -> Union[np.ndarray, Tuple[np.ndarray, Dict[str, np.ndarray]]]:
820
+ """
821
+ Compute LRP attributions for a single instance.
822
+
823
+ Args:
824
+ instance: Input instance (1D or multi-dimensional array)
825
+ target_class: Target class for relevance initialization
826
+ return_layer_relevances: If True, also return relevances at each layer
827
+
828
+ Returns:
829
+ If return_layer_relevances is False:
830
+ Array of attribution scores for input features
831
+ If return_layer_relevances is True:
832
+ Tuple of (input_attributions, layer_relevances_dict)
833
+ """
834
+ model = self._get_pytorch_model()
835
+ model.eval()
836
+
837
+ # Prepare input with correct shape for model type (CNN vs MLP)
838
+ x = self._prepare_input_tensor(instance)
839
+ x.requires_grad_(False)
840
+
841
+ # =====================================================================
842
+ # Forward pass: collect activations at each layer
843
+ # =====================================================================
844
+ activations = OrderedDict()
845
+ activations["input"] = x.clone()
846
+
847
+ layer_list = [] # List of (idx, name, layer, input_activation)
848
+
849
+ current = x
850
+
851
+ # Handle Sequential models
852
+ if isinstance(model, nn.Sequential):
853
+ for idx, (name, layer) in enumerate(model.named_children()):
854
+ layer_list.append((idx, name, layer, current.clone()))
855
+ current = layer(current)
856
+ activations[f"layer_{idx}_{name}"] = current.clone()
857
+ else:
858
+ # For non-Sequential models, use hooks
859
+ hooks = []
860
+ layer_data = OrderedDict()
861
+
862
+ def make_hook(name):
863
+ def hook(module, input, output):
864
+ inp = input[0] if isinstance(input, tuple) else input
865
+ layer_data[name] = {
866
+ "input": inp.clone().detach(),
867
+ "output": output.clone().detach() if isinstance(output, torch.Tensor) else output
868
+ }
869
+ return hook
870
+
871
+ # Register hooks on relevant layers
872
+ idx = 0
873
+ for name, module in model.named_modules():
874
+ if isinstance(module, (*WEIGHTED_LAYERS, *NORMALIZATION_LAYERS, *ACTIVATION_LAYERS, *POOLING_LAYERS, *PASSTHROUGH_LAYERS)):
875
+ hooks.append(module.register_forward_hook(make_hook(f"{idx}_{name}")))
876
+ idx += 1
877
+
878
+ # Forward pass
879
+ current = model(x)
880
+
881
+ # Remove hooks
882
+ for h in hooks:
883
+ h.remove()
884
+
885
+ # Build layer list from collected data
886
+ for name, data in layer_data.items():
887
+ idx_str, layer_name = name.split("_", 1)
888
+ # Get the actual module
889
+ module = dict(model.named_modules()).get(layer_name)
890
+ if module is not None:
891
+ layer_list.append((int(idx_str), layer_name, module, data["input"]))
892
+
893
+ output = current
894
+
895
+ # =====================================================================
896
+ # Initialize relevance at output layer
897
+ # =====================================================================
898
+ if target_class is not None:
899
+ relevance = torch.zeros_like(output)
900
+ relevance[0, target_class] = output[0, target_class]
901
+ else:
902
+ relevance = output.clone()
903
+
904
+ # =====================================================================
905
+ # Backward pass: propagate relevance through layers
906
+ # =====================================================================
907
+ layer_relevances = OrderedDict()
908
+ layer_relevances["output"] = relevance.detach().cpu().numpy().flatten()
909
+
910
+ # Reverse through layers
911
+ for idx, name, layer, activation in reversed(layer_list):
912
+ rule = self._get_rule_for_layer(idx, type(layer).__name__)
913
+
914
+ # Propagate based on layer type
915
+ if isinstance(layer, nn.Linear):
916
+ # Flatten activation if needed
917
+ if activation.dim() > 2:
918
+ activation = activation.flatten(1)
919
+ if relevance.dim() > 2:
920
+ relevance = relevance.flatten(1)
921
+ relevance = self._propagate_linear(layer, activation, relevance, rule)
922
+
923
+ elif isinstance(layer, nn.Conv2d):
924
+ relevance = self._propagate_conv2d(layer, activation, relevance, rule)
925
+
926
+ elif isinstance(layer, NORMALIZATION_LAYERS):
927
+ relevance = self._propagate_batchnorm(layer, activation, relevance)
928
+
929
+ elif isinstance(layer, ACTIVATION_LAYERS):
930
+ relevance = self._propagate_activation(layer, activation, relevance)
931
+
932
+ elif isinstance(layer, nn.MaxPool2d):
933
+ relevance = self._propagate_maxpool2d(layer, activation, relevance)
934
+
935
+ elif isinstance(layer, (nn.AvgPool2d, nn.AdaptiveAvgPool2d)):
936
+ relevance = self._propagate_avgpool2d(layer, activation, relevance)
937
+
938
+ elif isinstance(layer, nn.Flatten):
939
+ # Reshape relevance back to pre-flatten shape
940
+ if activation.dim() > 2:
941
+ relevance = relevance.reshape(activation.shape)
942
+
943
+ elif isinstance(layer, PASSTHROUGH_LAYERS):
944
+ # Dropout and other passthrough layers
945
+ pass
946
+
947
+ layer_relevances[f"layer_{idx}_{name}"] = relevance.detach().cpu().numpy().flatten()
948
+
949
+ # Final relevance is the input attribution
950
+ input_relevance = relevance.detach().cpu().numpy().flatten()
951
+ layer_relevances["input"] = input_relevance
952
+
953
+ if return_layer_relevances:
954
+ return input_relevance, layer_relevances
955
+ return input_relevance
956
+
957
+ def explain(
958
+ self,
959
+ instance: np.ndarray,
960
+ target_class: Optional[int] = None,
961
+ return_convergence_delta: bool = False
962
+ ) -> Explanation:
963
+ """
964
+ Generate LRP explanation for an instance.
965
+
966
+ Args:
967
+ instance: Numpy array of input features (1D for tabular,
968
+ or multi-dimensional for images).
969
+ target_class: For classification, which class to explain.
970
+ If None and class_names are set, uses the predicted class;
+ otherwise the full output vector is explained.
971
+ return_convergence_delta: If True, include the convergence delta
972
+ (difference between sum of attributions and target output).
973
+ Should be close to 0 for correct LRP (conservation property).
974
+
975
+ Returns:
976
+ Explanation object with feature attributions.
977
+
978
+ Example:
979
+ >>> explanation = explainer.explain(instance)
980
+ >>> print(explanation.explanation_data["feature_attributions"])
981
+ """
982
+ instance = np.array(instance).astype(np.float32)
983
+ original_shape = instance.shape
984
+ instance_flat = instance.flatten()
985
+
986
+ # Determine target class if not specified
987
+ if target_class is None and self.class_names:
988
+ # Get prediction using properly shaped input
989
+ model = self._get_pytorch_model()
990
+ model.eval()
991
+ with torch.no_grad():
992
+ x = self._prepare_input_tensor(instance)
993
+ output = model(x)
994
+ target_class = int(torch.argmax(output, dim=1).item())
995
+
996
+ # Compute LRP attributions
997
+ attributions_raw = self._compute_lrp(instance, target_class)
998
+
999
+ # Build attributions dict
1000
+ if len(self.feature_names) == len(attributions_raw):
1001
+ attributions = {
1002
+ fname: float(attributions_raw[i])
1003
+ for i, fname in enumerate(self.feature_names)
1004
+ }
1005
+ else:
1006
+ # For images or mismatched feature names, use indices
1007
+ attributions = {
1008
+ f"feature_{i}": float(attributions_raw[i])
1009
+ for i in range(len(attributions_raw))
1010
+ }
1011
+
1012
+ # Determine class name
1013
+ if self.class_names and target_class is not None:
1014
+ label_name = self.class_names[target_class]
1015
+ else:
1016
+ label_name = f"class_{target_class}" if target_class is not None else "output"
1017
+
1018
+ explanation_data = {
1019
+ "feature_attributions": attributions,
1020
+ "attributions_raw": [float(x) for x in attributions_raw],
1021
+ "rule": self.rule,
1022
+ "epsilon": self.epsilon if self.rule in ["epsilon", "composite"] else None,
1023
+ "gamma": self.gamma if self.rule == "gamma" else None,
1024
+ "alpha": self.alpha if self.rule == "alpha_beta" else None,
1025
+ "beta": self.beta if self.rule == "alpha_beta" else None,
1026
+ "input_shape": list(original_shape)
1027
+ }
1028
+
1029
+ # Compute convergence delta (conservation check)
1030
+ if return_convergence_delta:
1031
+ model = self._get_pytorch_model()
1032
+ model.eval()
1033
+
1034
+ with torch.no_grad():
1035
+ # Use the helper method to get properly shaped input
1036
+ x = self._prepare_input_tensor(instance)
1037
+ output = model(x)
1038
+
1039
+ if target_class is not None:
1040
+ target_output = output[0, target_class].item()
1041
+ else:
1042
+ target_output = output.sum().item()
1043
+
1044
+ attribution_sum = sum(attributions.values())
1045
+ convergence_delta = abs(target_output - attribution_sum)
1046
+
1047
+ explanation_data["target_output"] = float(target_output)
1048
+ explanation_data["attribution_sum"] = float(attribution_sum)
1049
+ explanation_data["convergence_delta"] = float(convergence_delta)
1050
+
1051
+ return Explanation(
1052
+ explainer_name="LRP",
1053
+ target_class=label_name,
1054
+ explanation_data=explanation_data,
1055
+ feature_names=self.feature_names
1056
+ )
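+
+ # Usage sketch (illustrative; `adapter`, `feature_names` and `instance` as
+ # in the module docstring example):
+ #
+ #     explainer = LRPExplainer(adapter, feature_names, rule="epsilon")
+ #     exp = explainer.explain(instance, return_convergence_delta=True)
+ #     data = exp.explanation_data
+ #     # conservation check: the delta should be close to zero
+ #     print(data["attribution_sum"], data["target_output"], data["convergence_delta"])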
1057
+
1058
+ def explain_batch(
1059
+ self,
1060
+ X: np.ndarray,
1061
+ target_class: Optional[int] = None
1062
+ ) -> List[Explanation]:
1063
+ """
1064
+ Generate explanations for multiple instances.
1065
+
1066
+ Args:
1067
+ X: Array of instances. For tabular: (n_samples, n_features).
1068
+ For images: (n_samples, channels, height, width) or similar.
1069
+ target_class: Target class for all instances. If None,
1070
+ uses predicted class for each instance.
1071
+
1072
+ Returns:
1073
+ List of Explanation objects.
1074
+ """
1075
+ X = np.array(X)
1076
+
1077
+ # Handle single instance
1078
+ if X.ndim == 1:
1079
+ X = X.reshape(1, -1)
1080
+
1081
+ # For multi-dimensional data (images), first dim is batch
1082
+ n_samples = X.shape[0]
1083
+
1084
+ return [
1085
+ self.explain(X[i], target_class=target_class)
1086
+ for i in range(n_samples)
1087
+ ]
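+
+ # Batch usage sketch (illustrative; `X_test` is an assumed test array):
+ #
+ #     for exp in explainer.explain_batch(X_test[:5]):
+ #         print(exp.target_class, exp.explanation_data["feature_attributions"])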
1088
+
1089
+ def explain_with_layer_relevances(
1090
+ self,
1091
+ instance: np.ndarray,
1092
+ target_class: Optional[int] = None
1093
+ ) -> Dict[str, Any]:
1094
+ """
1095
+ Compute LRP with layer-wise relevance scores for detailed analysis.
1096
+
1097
+ This method returns relevance scores at each layer, which is useful
1098
+ for understanding how relevance flows through the network and
1099
+ verifying the conservation property.
1100
+
1101
+ Args:
1102
+ instance: Input instance.
1103
+ target_class: Target class for relevance computation.
1104
+
1105
+ Returns:
1106
+ Dictionary containing:
1107
+ - input_relevances: Final attribution scores for input features
1108
+ - layer_relevances: Dict mapping layer names to relevance arrays
1109
+ - target_class: The target class used
1110
+ - rule: The rule used for computation
1111
+ """
1112
+ instance = np.array(instance).astype(np.float32)
1113
+
1114
+ # Determine target class if not specified
1115
+ if target_class is None and self.class_names:
1116
+ # Get prediction using properly shaped input
1117
+ model = self._get_pytorch_model()
1118
+ model.eval()
1119
+ with torch.no_grad():
1120
+ x = self._prepare_input_tensor(instance)
1121
+ output = model(x)
1122
+ target_class = int(torch.argmax(output, dim=1).item())
1123
+
1124
+ # Compute LRP with layer relevances
1125
+ input_relevances, layer_relevances = self._compute_lrp(
1126
+ instance, target_class, return_layer_relevances=True
1127
+ )
1128
+
1129
+ return {
1130
+ "input_relevances": [float(x) for x in input_relevances],
1131
+ "layer_relevances": {
1132
+ name: [float(x) for x in rel] if isinstance(rel, np.ndarray) else float(rel)
1133
+ for name, rel in layer_relevances.items()
1134
+ },
1135
+ "target_class": target_class,
1136
+ "rule": self.rule,
1137
+ "feature_names": self.feature_names
1138
+ }
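+
+ # Layer-wise conservation sketch (illustrative): the per-layer sums should
+ # stay roughly constant as relevance flows backward.
+ #
+ #     result = explainer.explain_with_layer_relevances(instance)
+ #     for name, rel in result["layer_relevances"].items():
+ #         print(name, sum(rel))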
1139
+
1140
+ def compare_rules(
1141
+ self,
1142
+ instance: np.ndarray,
1143
+ target_class: Optional[int] = None,
1144
+ rules: Optional[List[str]] = None
1145
+ ) -> Dict[str, Dict[str, Any]]:
1146
+ """
1147
+ Compare different LRP rules on the same instance.
1148
+
1149
+ Useful for understanding how different rules affect attributions
1150
+ and for selecting the most appropriate rule for your use case.
1151
+
1152
+ Args:
1153
+ instance: Input instance.
1154
+ target_class: Target class for comparison.
1155
+ rules: List of rules to compare. If None, compares all rules.
1156
+
1157
+ Returns:
1158
+ Dictionary mapping rule names to their attribution results.
1159
+ """
1160
+ instance = np.array(instance).astype(np.float32)
1161
+
1162
+ # Determine target class
1163
+ if target_class is None and self.class_names:
1164
+ # Get prediction using properly shaped input
1165
+ model = self._get_pytorch_model()
1166
+ model.eval()
1167
+ with torch.no_grad():
1168
+ x = self._prepare_input_tensor(instance)
1169
+ output = model(x)
1170
+ target_class = int(torch.argmax(output, dim=1).item())
1171
+
1172
+ if rules is None:
1173
+ rules = ["epsilon", "gamma", "alpha_beta", "z_plus"]
1174
+
1175
+ results = {}
1176
+
1177
+ # Save original settings
1178
+ original_rule = self.rule
1179
+
1180
+ for rule in rules:
1181
+ self.rule = rule
1182
+
1183
+ try:
1184
+ attributions = self._compute_lrp(instance, target_class)
1185
+
1186
+ # Find top feature
1187
+ top_idx = int(np.argmax(np.abs(attributions)))
1188
+ if top_idx < len(self.feature_names):
1189
+ top_feature = self.feature_names[top_idx]
1190
+ else:
1191
+ top_feature = f"feature_{top_idx}"
1192
+
1193
+ results[rule] = {
1194
+ "attributions": [float(x) for x in attributions],
1195
+ "top_feature": top_feature,
1196
+ "top_attribution": float(attributions[top_idx]),
1197
+ "attribution_sum": float(np.sum(attributions)),
1198
+ "attribution_range": (float(np.min(attributions)), float(np.max(attributions)))
1199
+ }
1200
+ except Exception as e:
1201
+ results[rule] = {"error": str(e)}
1202
+
1203
+ # Restore original rule
1204
+ self.rule = original_rule
1205
+
1206
+ return results
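+
+
+ # Rule-comparison sketch (illustrative; `explainer` and `instance` as in
+ # the module docstring example):
+ #
+ #     summary = explainer.compare_rules(instance, rules=["epsilon", "gamma"])
+ #     for rule, res in summary.items():
+ #         if "error" not in res:
+ #             print(rule, res["top_feature"], res["attribution_sum"])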