birder-0.2.1-py3-none-any.whl → birder-0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. birder/adversarial/__init__.py +13 -0
  2. birder/adversarial/base.py +101 -0
  3. birder/adversarial/deepfool.py +173 -0
  4. birder/adversarial/fgsm.py +51 -18
  5. birder/adversarial/pgd.py +79 -28
  6. birder/adversarial/simba.py +172 -0
  7. birder/common/training_cli.py +11 -3
  8. birder/common/training_utils.py +18 -1
  9. birder/inference/data_parallel.py +1 -2
  10. birder/introspection/__init__.py +10 -6
  11. birder/introspection/attention_rollout.py +122 -54
  12. birder/introspection/base.py +73 -29
  13. birder/introspection/gradcam.py +71 -100
  14. birder/introspection/guided_backprop.py +146 -72
  15. birder/introspection/transformer_attribution.py +182 -0
  16. birder/net/detection/deformable_detr.py +14 -12
  17. birder/net/detection/detr.py +7 -3
  18. birder/net/detection/rt_detr_v1.py +3 -3
  19. birder/net/detection/yolo_v3.py +6 -11
  20. birder/net/detection/yolo_v4.py +7 -18
  21. birder/net/detection/yolo_v4_tiny.py +3 -3
  22. birder/net/fastvit.py +1 -1
  23. birder/net/mim/mae_vit.py +7 -8
  24. birder/net/pit.py +1 -1
  25. birder/net/resnet_v1.py +94 -34
  26. birder/net/ssl/data2vec.py +1 -1
  27. birder/net/ssl/data2vec2.py +4 -2
  28. birder/results/gui.py +15 -2
  29. birder/scripts/predict_detection.py +33 -1
  30. birder/scripts/train.py +24 -17
  31. birder/scripts/train_barlow_twins.py +10 -7
  32. birder/scripts/train_byol.py +10 -7
  33. birder/scripts/train_capi.py +12 -9
  34. birder/scripts/train_data2vec.py +10 -7
  35. birder/scripts/train_data2vec2.py +10 -7
  36. birder/scripts/train_detection.py +42 -18
  37. birder/scripts/train_dino_v1.py +10 -7
  38. birder/scripts/train_dino_v2.py +10 -7
  39. birder/scripts/train_dino_v2_dist.py +17 -7
  40. birder/scripts/train_franca.py +10 -7
  41. birder/scripts/train_i_jepa.py +17 -13
  42. birder/scripts/train_ibot.py +10 -7
  43. birder/scripts/train_kd.py +24 -18
  44. birder/scripts/train_mim.py +11 -10
  45. birder/scripts/train_mmcr.py +10 -7
  46. birder/scripts/train_rotnet.py +10 -7
  47. birder/scripts/train_simclr.py +10 -7
  48. birder/scripts/train_vicreg.py +10 -7
  49. birder/tools/__main__.py +6 -2
  50. birder/tools/adversarial.py +147 -96
  51. birder/tools/auto_anchors.py +361 -0
  52. birder/tools/ensemble_model.py +1 -1
  53. birder/tools/introspection.py +58 -31
  54. birder/version.py +1 -1
  55. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
  56. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
  57. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
  58. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
  59. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
  60. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/introspection/gradcam.py
@@ -1,5 +1,9 @@
 """
-Adapted from https://github.com/jacobgil/pytorch-grad-cam
+Gradient-weighted Class Activation Mapping (Grad-CAM), adapted from
+https://github.com/jacobgil/pytorch-grad-cam
+
+Paper "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization",
+https://arxiv.org/abs/1610.02391
 """
 
 # Reference license: MIT
@@ -16,71 +20,51 @@ from torch import nn
 from torch.utils.hooks import RemovableHandle
 
 from birder.introspection.base import InterpretabilityResult
-from birder.introspection.base import Interpreter
+from birder.introspection.base import predict_class
+from birder.introspection.base import preprocess_image
+from birder.introspection.base import scale_cam_image
 from birder.introspection.base import show_mask_on_image
+from birder.introspection.base import validate_target_class
 
 
-def _scale_cam_image(
-    cam: npt.NDArray[np.float32], target_size: Optional[tuple[int, int]] = None
-) -> npt.NDArray[np.float32]:
-    result = []
-    for img in cam:
-        img = img - np.min(img)
-        img = img / (1e-7 + np.max(img))
-        if target_size is not None:
-            img = np.array(Image.fromarray(img).resize(target_size))
-
-        result.append(img)
-
-    return np.array(result, dtype=np.float32)
-
-
-class ClassifierOutputTarget:
-    def __init__(self, category: int) -> None:
-        self.category = category
+def compute_cam(activations: npt.NDArray[np.float32], gradients: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
+    weights: npt.NDArray[np.float32] = np.mean(gradients, axis=(2, 3))
+    weighted_activations = weights[:, :, None, None] * activations
+    cam: npt.NDArray[np.float32] = weighted_activations.sum(axis=1)
+    cam = np.maximum(cam, 0)
 
-    def __call__(self, model_output: torch.Tensor) -> torch.Tensor:
-        if len(model_output.shape) == 1:
-            return model_output[self.category]
+    return cam
 
-        return model_output[:, self.category]
-
-
-class ActivationsAndGradients:
-    """
-    Class for extracting activations and
-    registering gradients from targeted intermediate layers
-    """
 
+class ActivationCapture:
     def __init__(
         self,
         model: nn.Module,
         target_layer: nn.Module,
-        reshape_transform: Optional[Callable[[torch.Tensor], torch.Tensor]],
+        reshape_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
     ) -> None:
         self.model = model
-        self.gradients: torch.Tensor
-        self.activations: torch.Tensor
+        self.target_layer = target_layer
         self.reshape_transform = reshape_transform
+
+        self.activations: Optional[torch.Tensor] = None
+        self.gradients: Optional[torch.Tensor] = None
         self.handles: list[RemovableHandle] = []
 
-        self.handles.append(target_layer.register_forward_hook(self.save_activation))
-        # Because of https://github.com/pytorch/pytorch/issues/61519,
-        # we don't use backward hook to record gradients.
-        self.handles.append(target_layer.register_forward_hook(self.save_gradient))
+        # Register hooks
+        self.handles.append(target_layer.register_forward_hook(self._save_activation))
+        self.handles.append(target_layer.register_forward_hook(self._save_gradient))
 
-    def save_activation(self, _module: nn.Module, _input: torch.Tensor, output: torch.Tensor) -> None:
+    def _save_activation(self, _module: nn.Module, _input: torch.Tensor, output: torch.Tensor) -> None:
         if self.reshape_transform is not None:
             output = self.reshape_transform(output)
 
         self.activations = output.cpu().detach()
 
-    def save_gradient(self, _module: nn.Module, _input: torch.Tensor, output: torch.Tensor) -> None:
+    def _save_gradient(self, _module: nn.Module, _input: torch.Tensor, output: torch.Tensor) -> None:
         if hasattr(output, "requires_grad") is False or output.requires_grad is False:
-            # You can only register hooks on tensor requires grad.
             return
 
-        # Gradients are computed in reverse order
         def _store_grad(grad: torch.Tensor) -> None:
             if self.reshape_transform is not None:
                 grad = self.reshape_transform(grad)
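
As an aside (annotation, not part of the diff): the new compute_cam helper is the textbook Grad-CAM computation. Per-channel weights are the spatially averaged gradients of the target score with respect to the captured activations, the activation maps are summed with those weights, and a ReLU keeps only the positive evidence:

\alpha_k^{c} = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^{c}}{\partial A_{ij}^{k}},
\qquad
L^{c}_{\text{Grad-CAM}} = \mathrm{ReLU}\Big( \sum_{k} \alpha_k^{c} A^{k} \Big)
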
@@ -100,77 +84,64 @@ class ActivationsAndGradients:
 class GradCAM:
     def __init__(
         self,
-        model: nn.Module,
+        net: nn.Module,
+        device: torch.device,
+        transform: Callable[..., torch.Tensor],
         target_layer: nn.Module,
         reshape_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
     ) -> None:
-        self.model = model.eval()
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
         self.target_layer = target_layer
-        self.activations_and_grads = ActivationsAndGradients(self.model, target_layer, reshape_transform)
-
-    def get_cam_image(
-        self, activations: npt.NDArray[np.float32], grads: npt.NDArray[np.float32]
-    ) -> npt.NDArray[np.float32]:
-        weights: npt.NDArray[np.float32] = np.mean(grads, axis=(2, 3))
-        weighted_activations = weights[:, :, None, None] * activations
-        cam: npt.NDArray[np.float32] = weighted_activations.sum(axis=1)
-
-        return cam
-
-    def compute_layer_cam(self, input_tensor: torch.Tensor) -> npt.NDArray[np.float32]:
-        target_size = (input_tensor.size(-1), input_tensor.size(-2))
 
-        layer_activations = self.activations_and_grads.activations.numpy()
-        layer_grads = self.activations_and_grads.gradients.numpy()
+        self.activation_capture = ActivationCapture(net, target_layer, reshape_transform)
 
-        cam = self.get_cam_image(layer_activations, layer_grads)
-        cam = np.maximum(cam, 0)
-        scaled = _scale_cam_image(cam, target_size)
-        return scaled[:, None, :]
+    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+        input_tensor.requires_grad_(True)
 
-    def __call__(
-        self, input_tensor: torch.Tensor, target: Optional[ClassifierOutputTarget] = None
-    ) -> npt.NDArray[np.float32]:
-        output = self.activations_and_grads(input_tensor)
-        if target is None:
-            category = np.argmax(output.cpu().data.numpy(), axis=-1)
-            target = ClassifierOutputTarget(category)
+        # Forward pass
+        logits = self.activation_capture(input_tensor)
 
-        self.model.zero_grad()
-        loss = target(output)
-        loss.backward(retain_graph=True)
-
-        cam_per_layer = self.compute_layer_cam(input_tensor)
-        cam_per_layer = np.mean(cam_per_layer, axis=1)
-        cam_per_layer = _scale_cam_image(cam_per_layer)
-
-        return cam_per_layer
+        # Determine target class
+        if target_class is None:
+            target_class = predict_class(logits)
+        else:
+            validate_target_class(target_class, logits.shape[-1])
 
-    def __del__(self) -> None:
-        self.activations_and_grads.release()
+        # Backward pass
+        self.net.zero_grad()
+        loss = logits[0, target_class]
+        loss.backward(retain_graph=False)
 
+        # Get captured activations and gradients
+        if self.activation_capture.activations is None:
+            raise RuntimeError("No activations captured")
 
-class GradCamInterpreter(Interpreter):
-    def __init__(
-        self,
-        model: nn.Module,
-        device: torch.device,
-        transform: Callable[..., torch.Tensor],
-        target_layer: nn.Module,
-        reshape_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-    ) -> None:
-        super().__init__(model, device, transform)
-        self.grad_cam = GradCAM(model, target_layer, reshape_transform=reshape_transform)
+        if self.activation_capture.gradients is None:
+            raise RuntimeError("No gradients captured")
 
-    def interpret(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-        (input_tensor, rgb_img) = self._preprocess_image(image)
+        activations = self.activation_capture.activations.numpy()
+        gradients = self.activation_capture.gradients.numpy()
 
-        if target_class is not None:
-            target = ClassifierOutputTarget(target_class)
-        else:
-            target = None
+        # Compute CAM
+        cam = compute_cam(activations, gradients)
+        target_size = (input_tensor.size(-1), input_tensor.size(-2))
+        cam_scaled = scale_cam_image(cam, target_size)
+        grayscale_cam = cam_scaled[0]
 
-        grayscale_cam = self.grad_cam(input_tensor, target=target)[0, :]
+        # Create visualization
         visualization = show_mask_on_image(rgb_img, grayscale_cam)
 
-        return InterpretabilityResult(rgb_img, visualization, raw_output=grayscale_cam)
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=grayscale_cam,
+            logits=logits.detach(),
+            predicted_class=target_class,
+        )
+
+    def __del__(self) -> None:
+        if hasattr(self, "activation_capture") is True:
+            self.activation_capture.release()
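
A minimal usage sketch of the refactored GradCAM entry point (not taken from the package docs). The tiny nn.Sequential classifier, the PIL-to-tensor transform, and "bird.jpeg" are placeholders, and the assumption that the transform maps a PIL image to a CHW float tensor (with preprocess_image handling batching and device placement) is mine; only the constructor and __call__ signatures come from the diff above.

import numpy as np
import torch
from PIL import Image
from torch import nn

from birder.introspection.gradcam import GradCAM

# Placeholder classifier - substitute a real Birder network and its own preprocessing
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

def transform(img: Image.Image) -> torch.Tensor:
    # Assumed contract: PIL image in, CHW float tensor out
    arr = np.asarray(img.convert("RGB").resize((224, 224)), dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW

grad_cam = GradCAM(net, torch.device("cpu"), transform, target_layer=net[0])
result = grad_cam("bird.jpeg")                # target_class=None -> explain the predicted class
heatmap_overlay = result.visualization        # blended mask from show_mask_on_image
print(result.predicted_class, result.raw_output.shape)
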
birder/introspection/guided_backprop.py
@@ -1,47 +1,29 @@
 """
-Adapted from https://github.com/jacobgil/pytorch-grad-cam
+Guided Backpropagation, adapted from
+https://github.com/jacobgil/pytorch-grad-cam
 
 Paper "Striving for Simplicity: The All Convolutional Net", https://arxiv.org/abs/1412.6806
 """
 
 # Reference license: MIT
 
+import math
+from collections.abc import Callable
 from pathlib import Path
 from typing import Any
 from typing import Optional
 
-import numpy as np
-import numpy.typing as npt
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from torch import nn
 from torch.autograd import Function
 
 from birder.introspection.base import InterpretabilityResult
-from birder.introspection.base import Interpreter
-
-
-def _deprocess_image(img: npt.NDArray[np.float32]) -> npt.NDArray[np.uint8]:
-    """
-    See https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65
-    """
-
-    img = img - np.mean(img)
-    img = img / (np.std(img) + 1e-5)
-    img = img * 0.1
-    img = img + 0.5
-    img = np.clip(img, 0, 1)
-
-    return np.array(img * 255).astype(np.uint8)
-
-
-# pylint: disable=protected-access
-def _replace_all_layer_type_recursive(model: nn.Module, old_layer_type: nn.Module, new_layer: nn.Module) -> None:
-    for name, layer in model._modules.items():
-        if isinstance(layer, old_layer_type):
-            model._modules[name] = new_layer
-
-        _replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
+from birder.introspection.base import deprocess_image
+from birder.introspection.base import predict_class
+from birder.introspection.base import preprocess_image
+from birder.introspection.base import validate_target_class
 
 
 # pylint: disable=abstract-method,arguments-differ
@@ -57,7 +39,6 @@ class GuidedBackpropReLU(Function):
     @staticmethod
     def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
         (input_img, _output) = ctx.saved_tensors
-        grad_input = None
 
         positive_mask_1 = (input_img > 0).type_as(grad_output)
         positive_mask_2 = (grad_output > 0).type_as(grad_output)
@@ -71,7 +52,7 @@ class GuidedBackpropReLU(Function):
 
 
 # pylint: disable=abstract-method,arguments-differ
-class GuidedBackpropSwish(Function):
+class GuidedBackpropSiLU(Function):
     @staticmethod
     def forward(ctx: Any, input_img: torch.Tensor) -> torch.Tensor:
         result = input_img * torch.sigmoid(input_img)
@@ -90,66 +71,159 @@ class GuidedBackpropSwish(Function):
         return grad_input
 
 
-class GuidedBackpropReLUAsModule(nn.Module):
-    def forward(self, input_img: torch.Tensor) -> Any:
-        return GuidedBackpropReLU.apply(input_img)
+# pylint: disable=abstract-method,arguments-differ
+class GuidedBackpropGELU(Function):
+    @staticmethod
+    def forward(ctx: Any, input_img: torch.Tensor) -> torch.Tensor:
+        result = F.gelu(input_img, approximate="none")  # pylint:disable=not-callable
+        ctx.save_for_backward(input_img)
+        return result
 
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
+        x = ctx.saved_tensors[0]
 
-class GuidedBackpropSwishAsModule(nn.Module):
-    def forward(self, input_img: torch.Tensor) -> Any:
-        return GuidedBackpropSwish.apply(input_img)
+        sqrt_2 = math.sqrt(2.0)
+        sqrt_2pi = math.sqrt(2.0 * math.pi)
 
+        cdf = 0.5 * (1.0 + torch.erf(x / sqrt_2))
+        pdf = torch.exp(-0.5 * x * x) / sqrt_2pi
 
-class GuidedBackpropGeLUAsModule(nn.Module):
-    def forward(self, input_img: torch.Tensor) -> Any:
-        return GuidedBackpropSwish.apply(input_img)
+        d_gelu = cdf + x * pdf
 
+        positive_mask_1 = (x > 0).type_as(grad_output)
+        positive_mask_2 = (grad_output > 0).type_as(grad_output)
 
-class GuidedBackpropHardswishAsModule(nn.Module):
-    def forward(self, input_img: torch.Tensor) -> Any:
-        return GuidedBackpropSwish.apply(input_img)
+        grad_input = grad_output * d_gelu * positive_mask_1 * positive_mask_2
 
+        return grad_input
 
-class GuidedBackpropModel:
-    def __init__(self, model: nn.Module) -> None:
-        self.model = model
-        self.model.eval()
 
-    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
-        return self.model(input_img)
+# pylint: disable=abstract-method,arguments-differ
+class GuidedBackpropHardswish(Function):
+    @staticmethod
+    def forward(ctx: Any, input_img: torch.Tensor) -> torch.Tensor:
+        result = F.hardswish(input_img)
+        ctx.save_for_backward(input_img)
+        return result
+
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
+        x = ctx.saved_tensors[0]
+
+        grad = torch.zeros_like(x)
 
-    def __call__(self, input_img: torch.Tensor, target_category: Optional[int] = None) -> npt.NDArray[np.float32]:
-        _replace_all_layer_type_recursive(self.model, nn.ReLU, GuidedBackpropReLUAsModule())
-        _replace_all_layer_type_recursive(self.model, nn.GELU, GuidedBackpropGeLUAsModule())
-        _replace_all_layer_type_recursive(self.model, nn.SiLU, GuidedBackpropSwishAsModule())
-        _replace_all_layer_type_recursive(self.model, nn.Hardswish, GuidedBackpropHardswishAsModule())
+        mask_mid = (x > -3) & (x < 3)
+        grad[mask_mid] = (2.0 * x[mask_mid] + 3.0) / 6.0
 
-        input_img = input_img.requires_grad_(True)
-        output = self.forward(input_img)
+        mask_high = x >= 3
+        grad[mask_high] = 1.0
 
-        if target_category is None:
-            target_category = np.argmax(output.cpu().data.numpy()).item()
+        positive_mask_1 = (x > 0).type_as(grad_output)
+        positive_mask_2 = (grad_output > 0).type_as(grad_output)
 
-        loss = output[0, target_category]
-        loss.backward(retain_graph=True)
+        grad_input = grad_output * grad * positive_mask_1 * positive_mask_2
+
+        return grad_input
 
-        output_grad = input_img.grad.cpu().data.numpy()
-        output_grad = output_grad[0, :, :, :]
-        output_grad = output_grad.transpose((1, 2, 0))
 
-        _replace_all_layer_type_recursive(self.model, GuidedBackpropHardswishAsModule, nn.Hardswish())
-        _replace_all_layer_type_recursive(self.model, GuidedBackpropSwishAsModule, nn.SiLU())
-        _replace_all_layer_type_recursive(self.model, GuidedBackpropGeLUAsModule, nn.GELU())
-        _replace_all_layer_type_recursive(self.model, GuidedBackpropReLUAsModule, nn.ReLU())
+class GuidedReLU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GuidedBackpropReLU.apply(x)
 
-        return output_grad  # type: ignore[no-any-return]
 
+class GuidedSiLU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GuidedBackpropSiLU.apply(x)
 
-class GuidedBackpropInterpreter(Interpreter):
-    def interpret(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-        (input_tensor, rgb_img) = self._preprocess_image(image)
 
-        guided_bp = GuidedBackpropModel(self.model)
-        bp_img = guided_bp(input_tensor, target_category=target_class)
+class GuidedGELU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GuidedBackpropGELU.apply(x)
 
-        return InterpretabilityResult(rgb_img, _deprocess_image(bp_img * rgb_img), raw_output=bp_img)
+
+class GuidedHardswish(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GuidedBackpropHardswish.apply(x)
+
+
+# Activation replacement mapping
+ACTIVATION_REPLACEMENTS: dict[type, type] = {
+    nn.ReLU: GuidedReLU,
+    nn.SiLU: GuidedSiLU,
+    nn.GELU: GuidedGELU,
+    nn.Hardswish: GuidedHardswish,
+}
+
+
+def replace_activations_recursive(model: nn.Module, replacements: dict[type, type]) -> None:
+    """
+    NOTE: This ONLY works for activations defined as nn.Module objects (e.g., self.act = nn.ReLU()).
+    It will NOT affect functional calls inside forward methods, such as F.relu(x) or F.gelu(x).
+    """
+
+    for name, module in list(model._modules.items()):  # pylint: disable=protected-access
+        for old_type, new_type in replacements.items():
+            if isinstance(module, old_type):
+                model._modules[name] = new_type()  # pylint: disable=protected-access
+                break
+        else:
+            # Recurse into submodules
+            replace_activations_recursive(module, replacements)
+
+
+def restore_activations_recursive(model: nn.Module, guided_types: dict[type, type]) -> None:
+    reverse_mapping = {v: k for k, v in guided_types.items()}
+    for name, module in list(model._modules.items()):  # pylint: disable=protected-access
+        for guided_type, original_type in reverse_mapping.items():
+            if isinstance(module, guided_type):
+                model._modules[name] = original_type()  # pylint: disable=protected-access
+                break
+        else:
+            restore_activations_recursive(module, guided_types)
+
+
+class GuidedBackprop:
+    def __init__(self, net: nn.Module, device: torch.device, transform: Callable[..., torch.Tensor]) -> None:
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
+
+    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+
+        # Get prediction
+        with torch.inference_mode():
+            logits = self.net(input_tensor)
+
+        if target_class is None:
+            target_class = predict_class(logits)
+        else:
+            validate_target_class(target_class, logits.shape[-1])
+
+        # Replace activations with guided versions
+        replace_activations_recursive(self.net, ACTIVATION_REPLACEMENTS)
+
+        try:
+            input_tensor = input_tensor.detach().requires_grad_(True)
+            output = self.net(input_tensor)
+
+            loss = output[0, target_class]
+            loss.backward(retain_graph=False)
+
+            gradients = input_tensor.grad.cpu().numpy()
+            gradients = gradients[0, :, :, :]  # Remove batch dim
+            gradients = gradients.transpose((1, 2, 0))  # CHW -> HWC
+
+        finally:
+            restore_activations_recursive(self.net, ACTIVATION_REPLACEMENTS)
+
+        visualization = deprocess_image(gradients * rgb_img)
+
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=gradients,
+            logits=logits.detach(),
+            predicted_class=target_class,
+        )
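
The new GuidedBackprop class follows the same calling convention as GradCAM. A minimal sketch (again with a placeholder network, transform, and image path of my own, and the same assumed PIL-to-CHW transform contract; only the constructor and __call__ signatures come from this diff):

import numpy as np
import torch
from PIL import Image
from torch import nn

from birder.introspection.guided_backprop import GuidedBackprop

# Placeholder classifier; its nn.ReLU module is swapped for GuidedReLU during the call and restored afterwards
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))

def transform(img: Image.Image) -> torch.Tensor:
    arr = np.asarray(img.convert("RGB").resize((224, 224)), dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW

guided_bp = GuidedBackprop(net, torch.device("cpu"), transform)
result = guided_bp("bird.jpeg")
saliency = result.raw_output          # HWC input-gradient map
overlay = result.visualization        # deprocess_image(gradients * rgb_img), per the diff above
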
birder/introspection/transformer_attribution.py (new file)
@@ -0,0 +1,182 @@
+"""
+Transformer Attribution (Gradient-weighted Attention Rollout), adapted from
+https://github.com/hila-chefer/Transformer-Explainability
+
+Paper "Transformer Interpretability Beyond Attention Visualization", https://arxiv.org/abs/2012.09838
+"""
+
+# Reference license: MIT
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from torch import nn
+
+from birder.introspection.base import InterpretabilityResult
+from birder.introspection.base import predict_class
+from birder.introspection.base import preprocess_image
+from birder.introspection.base import show_mask_on_image
+from birder.introspection.base import validate_target_class
+from birder.net.vit import Encoder
+
+
+def compute_attribution_rollout(
+    attributions: list[tuple[torch.Tensor, torch.Tensor]], num_special_tokens: int, patch_grid_shape: tuple[int, int]
+) -> torch.Tensor:
+    """
+    NOTE: Uses gradient norm per token instead of element-wise grad * attention multiplication.
+    """
+
+    # Assume batch size = 1
+    num_tokens = attributions[0][0].size(-1)
+    device = attributions[0][0].device
+
+    result = torch.eye(num_tokens, device=device)
+    with torch.no_grad():
+        for attn_weights, output_grad in attributions:
+            # Compute token importance from output gradient norm across embedding dimension
+            token_importance = output_grad.norm(dim=-1, keepdim=True)
+            token_importance = token_importance.transpose(-1, -2)
+
+            # Weight attention patterns by token importance
+            weighted_attn = attn_weights * token_importance.unsqueeze(1)
+
+            # Fuse attention heads and apply non-negativity constraint
+            relevance = weighted_attn.mean(dim=1).clamp(min=0)
+
+            # Add residual connection and normalize
+            eye = torch.eye(num_tokens, device=device)
+            normalized = (relevance + eye) / 2.0
+            normalized = normalized / normalized.sum(dim=-1, keepdim=True)
+
+            # Accumulate attention across layers
+            result = torch.matmul(normalized, result)
+
+    rollout = result[0]
+
+    if 0 < num_special_tokens:
+        source_to_patches = rollout[:num_special_tokens, num_special_tokens:]
+        mask = source_to_patches.mean(dim=0)
+    else:
+        mask = rollout.mean(dim=0)
+
+    mask = mask / (mask.max() + 1e-8)
+
+    (grid_h, grid_w) = patch_grid_shape
+    mask = mask.reshape(grid_h, grid_w)
+
+    return mask
+
+
+class AttributionGatherer:
+    def __init__(self, net: nn.Module, attention_layer_name: str) -> None:
+        assert hasattr(net, "encoder") is True and isinstance(net.encoder, Encoder)
+
+        net.encoder.set_need_attn()
+
+        self.net = net
+        self.handles: list[torch.utils.hooks.RemovableHandle] = []
+        self._gradients: list[torch.Tensor] = []
+        self._attention_weights: list[torch.Tensor] = []
+
+        for name, module in self.net.named_modules():
+            if name.endswith(attention_layer_name) is True:
+                handle = module.register_forward_hook(self._capture_forward)
+                self.handles.append(handle)
+
+    def _capture_forward(
+        self, _module: nn.Module, _inputs: tuple[torch.Tensor, ...], output: tuple[torch.Tensor, ...] | torch.Tensor
+    ) -> None:
+        output_tensor = output[0]
+        attn_weights = output[1]
+
+        self._attention_weights.append(attn_weights.detach())
+        if output_tensor.requires_grad:
+
+            def _store_grad(grad: torch.Tensor) -> None:
+                self._gradients.append(grad.detach())
+
+            output_tensor.register_hook(_store_grad)
+
+    def get_captured_data(self) -> list[tuple[torch.Tensor, torch.Tensor]]:
+        if len(self._attention_weights) != len(self._gradients):
+            raise RuntimeError(
+                f"Mismatch between attention weights ({len(self._attention_weights)}) "
+                f"and gradients ({len(self._gradients)}). Ensure backward() was called."
+            )
+
+        if len(self._attention_weights) == 0:
+            raise RuntimeError("No attention data captured. Ensure the model has attention layers.")
+
+        # Pair attention weights with output gradients (gradients reversed to match forward order)
+        results = [(attn.cpu(), grad.cpu()) for attn, grad in zip(self._attention_weights, reversed(self._gradients))]
+
+        # Clear storage for next forward pass
+        self._gradients = []
+        self._attention_weights = []
+
+        return results
+
+    def release(self) -> None:
+        for handle in self.handles:
+            handle.remove()
+
+
+class TransformerAttribution:
+    def __init__(
+        self,
+        net: nn.Module,
+        device: torch.device,
+        transform: Callable[..., torch.Tensor],
+        attention_layer_name: str = "self_attention",
+    ) -> None:
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
+        self.gatherer = AttributionGatherer(net, attention_layer_name)
+
+    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+        input_tensor.requires_grad_(True)
+
+        self.net.zero_grad()
+        logits = self.net(input_tensor)
+
+        if target_class is None:
+            target_class = predict_class(logits)
+        else:
+            validate_target_class(target_class, logits.shape[-1])
+
+        score = logits[0, target_class]
+        score.backward()
+
+        attribution_data = self.gatherer.get_captured_data()
+
+        (_, _, H, W) = input_tensor.shape
+        patch_grid_shape = (H // self.net.stem_stride, W // self.net.stem_stride)
+
+        attribution_map = compute_attribution_rollout(
+            attribution_data, num_special_tokens=self.net.num_special_tokens, patch_grid_shape=patch_grid_shape
+        )
+
+        attribution_img = Image.fromarray(attribution_map.numpy())
+        attribution_img = attribution_img.resize((rgb_img.shape[1], rgb_img.shape[0]))
+        attribution_arr = np.array(attribution_img)
+
+        visualization = show_mask_on_image(rgb_img, attribution_arr, image_weight=0.4)
+
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=attribution_arr,
+            logits=logits.detach(),
+            predicted_class=target_class,
+        )
+
+    def __del__(self) -> None:
+        if hasattr(self, "gatherer") is True:
+            self.gatherer.release()
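
Written out as math (an annotation, not part of the file): compute_attribution_rollout starts from R = I and, for each attention layer l in forward order with per-head attention weights A^(l) (h heads) and the gradient of the target logit y_c with respect to that block's token outputs z^(l), applies

g_j^{(l)} = \big\lVert \partial y_c / \partial z_j^{(l)} \big\rVert_2,
\qquad
\bar{A}^{(l)} = \Big[ \tfrac{1}{h} \sum_{\text{heads}} A^{(l)} \, \mathrm{diag}\big(g^{(l)}\big) \Big]_{+},
\qquad
R \leftarrow \mathrm{rownorm}\!\Big( \tfrac{\bar{A}^{(l)} + I}{2} \Big) \, R

The final mask averages the special-token rows of R over the patch columns, max-normalizes, and reshapes to the patch grid; as the docstring notes, per-token gradient norms stand in for the element-wise grad ⊙ attention product of the original method.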