birder 0.2.1-py3-none-any.whl → 0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/__init__.py +13 -0
- birder/adversarial/base.py +101 -0
- birder/adversarial/deepfool.py +173 -0
- birder/adversarial/fgsm.py +51 -18
- birder/adversarial/pgd.py +79 -28
- birder/adversarial/simba.py +172 -0
- birder/common/training_cli.py +11 -3
- birder/common/training_utils.py +18 -1
- birder/inference/data_parallel.py +1 -2
- birder/introspection/__init__.py +10 -6
- birder/introspection/attention_rollout.py +122 -54
- birder/introspection/base.py +73 -29
- birder/introspection/gradcam.py +71 -100
- birder/introspection/guided_backprop.py +146 -72
- birder/introspection/transformer_attribution.py +182 -0
- birder/net/detection/deformable_detr.py +14 -12
- birder/net/detection/detr.py +7 -3
- birder/net/detection/rt_detr_v1.py +3 -3
- birder/net/detection/yolo_v3.py +6 -11
- birder/net/detection/yolo_v4.py +7 -18
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/fastvit.py +1 -1
- birder/net/mim/mae_vit.py +7 -8
- birder/net/pit.py +1 -1
- birder/net/resnet_v1.py +94 -34
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +4 -2
- birder/results/gui.py +15 -2
- birder/scripts/predict_detection.py +33 -1
- birder/scripts/train.py +24 -17
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +12 -9
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +42 -18
- birder/scripts/train_dino_v1.py +10 -7
- birder/scripts/train_dino_v2.py +10 -7
- birder/scripts/train_dino_v2_dist.py +17 -7
- birder/scripts/train_franca.py +10 -7
- birder/scripts/train_i_jepa.py +17 -13
- birder/scripts/train_ibot.py +10 -7
- birder/scripts/train_kd.py +24 -18
- birder/scripts/train_mim.py +11 -10
- birder/scripts/train_mmcr.py +10 -7
- birder/scripts/train_rotnet.py +10 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/__main__.py +6 -2
- birder/tools/adversarial.py +147 -96
- birder/tools/auto_anchors.py +361 -0
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +58 -31
- birder/version.py +1 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/introspection/gradcam.py
CHANGED
@@ -1,5 +1,9 @@
 """
-
+Gradient-weighted Class Activation Mapping (Grad-CAM), adapted from
+https://github.com/jacobgil/pytorch-grad-cam
+
+Paper "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization",
+https://arxiv.org/abs/1610.02391
 """
 
 # Reference license: MIT
@@ -16,71 +20,51 @@ from torch import nn
 from torch.utils.hooks import RemovableHandle
 
 from birder.introspection.base import InterpretabilityResult
-from birder.introspection.base import
+from birder.introspection.base import predict_class
+from birder.introspection.base import preprocess_image
+from birder.introspection.base import scale_cam_image
 from birder.introspection.base import show_mask_on_image
+from birder.introspection.base import validate_target_class
 
 
-def
-
-
-
-
-        img = img - np.min(img)
-        img = img / (1e-7 + np.max(img))
-        if target_size is not None:
-            img = np.array(Image.fromarray(img).resize(target_size))
-
-        result.append(img)
-
-    return np.array(result, dtype=np.float32)
-
-
-class ClassifierOutputTarget:
-    def __init__(self, category: int) -> None:
-        self.category = category
+def compute_cam(activations: npt.NDArray[np.float32], gradients: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
+    weights: npt.NDArray[np.float32] = np.mean(gradients, axis=(2, 3))
+    weighted_activations = weights[:, :, None, None] * activations
+    cam: npt.NDArray[np.float32] = weighted_activations.sum(axis=1)
+    cam = np.maximum(cam, 0)
 
-
-        if len(model_output.shape) == 1:
-            return model_output[self.category]
+    return cam
 
-        return model_output[:, self.category]
-
-
-class ActivationsAndGradients:
-    """
-    Class for extracting activations and
-    registering gradients from targeted intermediate layers
-    """
 
+class ActivationCapture:
     def __init__(
         self,
         model: nn.Module,
         target_layer: nn.Module,
-        reshape_transform: Optional[Callable[[torch.Tensor], torch.Tensor]],
+        reshape_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
     ) -> None:
         self.model = model
-        self.
-        self.activations: torch.Tensor
+        self.target_layer = target_layer
         self.reshape_transform = reshape_transform
+
+        self.activations: Optional[torch.Tensor] = None
+        self.gradients: Optional[torch.Tensor] = None
         self.handles: list[RemovableHandle] = []
 
-
-
-
-        self.handles.append(target_layer.register_forward_hook(self.save_gradient))
+        # Register hooks
+        self.handles.append(target_layer.register_forward_hook(self._save_activation))
+        self.handles.append(target_layer.register_forward_hook(self._save_gradient))
 
-    def
+    def _save_activation(self, _module: nn.Module, _input: torch.Tensor, output: torch.Tensor) -> None:
         if self.reshape_transform is not None:
             output = self.reshape_transform(output)
 
         self.activations = output.cpu().detach()
 
-    def
+    def _save_gradient(self, _module: nn.Module, _input: torch.Tensor, output: torch.Tensor) -> None:
         if hasattr(output, "requires_grad") is False or output.requires_grad is False:
-            # You can only register hooks on tensor requires grad.
             return
 
-        # Gradients are computed in reverse order
         def _store_grad(grad: torch.Tensor) -> None:
             if self.reshape_transform is not None:
                 grad = self.reshape_transform(grad)
@@ -100,77 +84,64 @@ class ActivationsAndGradients:
 class GradCAM:
     def __init__(
         self,
-
+        net: nn.Module,
+        device: torch.device,
+        transform: Callable[..., torch.Tensor],
         target_layer: nn.Module,
         reshape_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
     ) -> None:
-        self.
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
        self.target_layer = target_layer
-        self.activations_and_grads = ActivationsAndGradients(self.model, target_layer, reshape_transform)
-
-    def get_cam_image(
-        self, activations: npt.NDArray[np.float32], grads: npt.NDArray[np.float32]
-    ) -> npt.NDArray[np.float32]:
-        weights: npt.NDArray[np.float32] = np.mean(grads, axis=(2, 3))
-        weighted_activations = weights[:, :, None, None] * activations
-        cam: npt.NDArray[np.float32] = weighted_activations.sum(axis=1)
-
-        return cam
-
-    def compute_layer_cam(self, input_tensor: torch.Tensor) -> npt.NDArray[np.float32]:
-        target_size = (input_tensor.size(-1), input_tensor.size(-2))
 
-
-        layer_grads = self.activations_and_grads.gradients.numpy()
+        self.activation_capture = ActivationCapture(net, target_layer, reshape_transform)
 
-
-
-
-        return scaled[:, None, :]
+    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+        input_tensor.requires_grad_(True)
 
-
-
-    ) -> npt.NDArray[np.float32]:
-        output = self.activations_and_grads(input_tensor)
-        if target is None:
-            category = np.argmax(output.cpu().data.numpy(), axis=-1)
-            target = ClassifierOutputTarget(category)
+        # Forward pass
+        logits = self.activation_capture(input_tensor)
 
-
-
-
-
-
-        cam_per_layer = np.mean(cam_per_layer, axis=1)
-        cam_per_layer = _scale_cam_image(cam_per_layer)
-
-        return cam_per_layer
+        # Determine target class
+        if target_class is None:
+            target_class = predict_class(logits)
+        else:
+            validate_target_class(target_class, logits.shape[-1])
 
-
-        self.
+        # Backward pass
+        self.net.zero_grad()
+        loss = logits[0, target_class]
+        loss.backward(retain_graph=False)
 
+        # Get captured activations and gradients
+        if self.activation_capture.activations is None:
+            raise RuntimeError("No activations captured")
 
-
-
-        self,
-        model: nn.Module,
-        device: torch.device,
-        transform: Callable[..., torch.Tensor],
-        target_layer: nn.Module,
-        reshape_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-    ) -> None:
-        super().__init__(model, device, transform)
-        self.grad_cam = GradCAM(model, target_layer, reshape_transform=reshape_transform)
+        if self.activation_capture.gradients is None:
+            raise RuntimeError("No gradients captured")
 
-
-
+        activations = self.activation_capture.activations.numpy()
+        gradients = self.activation_capture.gradients.numpy()
 
-
-
-
-
+        # Compute CAM
+        cam = compute_cam(activations, gradients)
+        target_size = (input_tensor.size(-1), input_tensor.size(-2))
+        cam_scaled = scale_cam_image(cam, target_size)
+        grayscale_cam = cam_scaled[0]
 
-
+        # Create visualization
         visualization = show_mask_on_image(rgb_img, grayscale_cam)
 
-        return InterpretabilityResult(
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=grayscale_cam,
+            logits=logits.detach(),
+            predicted_class=target_class,
+        )
+
+    def __del__(self) -> None:
+        if hasattr(self, "activation_capture") is True:
+            self.activation_capture.release()
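The reworked GradCAM class is now a self-contained callable: the network, device, and inference transform are passed at construction time, and each call on an image returns an InterpretabilityResult. A minimal usage sketch, assuming a torchvision-style transform and a toy CNN in place of a real birder network (the model, transform, and image path below are illustrative placeholders, not birder APIs):

import torch
from torch import nn
from torchvision import transforms

from birder.introspection.gradcam import GradCAM

device = torch.device("cpu")
net = nn.Sequential(  # toy CNN with 8 output classes, stands in for a birder model
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 8),
)
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

grad_cam = GradCAM(net, device, transform, target_layer=net[0])  # hook the conv block
result = grad_cam("bird.jpeg")         # target_class defaults to the predicted class
heatmap = result.visualization         # CAM mask blended onto the input image
print(result.predicted_class, result.raw_output.shape)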
birder/introspection/guided_backprop.py
CHANGED

@@ -1,47 +1,29 @@
 """
-
+Guided Backpropagation, adapted from
+https://github.com/jacobgil/pytorch-grad-cam
 
 Paper "Striving for Simplicity: The All Convolutional Net", https://arxiv.org/abs/1412.6806
 """
 
 # Reference license: MIT
 
+import math
+from collections.abc import Callable
 from pathlib import Path
 from typing import Any
 from typing import Optional
 
-import numpy as np
-import numpy.typing as npt
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from torch import nn
 from torch.autograd import Function
 
 from birder.introspection.base import InterpretabilityResult
-from birder.introspection.base import
-
-
-
-    """
-    See https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65
-    """
-
-    img = img - np.mean(img)
-    img = img / (np.std(img) + 1e-5)
-    img = img * 0.1
-    img = img + 0.5
-    img = np.clip(img, 0, 1)
-
-    return np.array(img * 255).astype(np.uint8)
-
-
-# pylint: disable=protected-access
-def _replace_all_layer_type_recursive(model: nn.Module, old_layer_type: nn.Module, new_layer: nn.Module) -> None:
-    for name, layer in model._modules.items():
-        if isinstance(layer, old_layer_type):
-            model._modules[name] = new_layer
-
-        _replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
+from birder.introspection.base import deprocess_image
+from birder.introspection.base import predict_class
+from birder.introspection.base import preprocess_image
+from birder.introspection.base import validate_target_class
 
 
 # pylint: disable=abstract-method,arguments-differ
@@ -57,7 +39,6 @@ class GuidedBackpropReLU(Function):
     @staticmethod
     def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
         (input_img, _output) = ctx.saved_tensors
-        grad_input = None
 
         positive_mask_1 = (input_img > 0).type_as(grad_output)
         positive_mask_2 = (grad_output > 0).type_as(grad_output)
@@ -71,7 +52,7 @@ class GuidedBackpropReLU(Function):
 
 
 # pylint: disable=abstract-method,arguments-differ
-class
+class GuidedBackpropSiLU(Function):
     @staticmethod
     def forward(ctx: Any, input_img: torch.Tensor) -> torch.Tensor:
         result = input_img * torch.sigmoid(input_img)
@@ -90,66 +71,159 @@ class GuidedBackpropSwish(Function):
         return grad_input
 
 
-
-
-
+# pylint: disable=abstract-method,arguments-differ
+class GuidedBackpropGELU(Function):
+    @staticmethod
+    def forward(ctx: Any, input_img: torch.Tensor) -> torch.Tensor:
+        result = F.gelu(input_img, approximate="none")  # pylint:disable=not-callable
+        ctx.save_for_backward(input_img)
+        return result
 
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
+        x = ctx.saved_tensors[0]
 
-
-
-        return GuidedBackpropSwish.apply(input_img)
+        sqrt_2 = math.sqrt(2.0)
+        sqrt_2pi = math.sqrt(2.0 * math.pi)
 
+        cdf = 0.5 * (1.0 + torch.erf(x / sqrt_2))
+        pdf = torch.exp(-0.5 * x * x) / sqrt_2pi
 
-
-    def forward(self, input_img: torch.Tensor) -> Any:
-        return GuidedBackpropSwish.apply(input_img)
+        d_gelu = cdf + x * pdf
 
+        positive_mask_1 = (x > 0).type_as(grad_output)
+        positive_mask_2 = (grad_output > 0).type_as(grad_output)
 
-
-    def forward(self, input_img: torch.Tensor) -> Any:
-        return GuidedBackpropSwish.apply(input_img)
+        grad_input = grad_output * d_gelu * positive_mask_1 * positive_mask_2
 
+        return grad_input
 
-class GuidedBackpropModel:
-    def __init__(self, model: nn.Module) -> None:
-        self.model = model
-        self.model.eval()
 
-
-
+# pylint: disable=abstract-method,arguments-differ
+class GuidedBackpropHardswish(Function):
+    @staticmethod
+    def forward(ctx: Any, input_img: torch.Tensor) -> torch.Tensor:
+        result = F.hardswish(input_img)
+        ctx.save_for_backward(input_img)
+        return result
+
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
+        x = ctx.saved_tensors[0]
+
+        grad = torch.zeros_like(x)
 
-
-
-        _replace_all_layer_type_recursive(self.model, nn.GELU, GuidedBackpropGeLUAsModule())
-        _replace_all_layer_type_recursive(self.model, nn.SiLU, GuidedBackpropSwishAsModule())
-        _replace_all_layer_type_recursive(self.model, nn.Hardswish, GuidedBackpropHardswishAsModule())
+        mask_mid = (x > -3) & (x < 3)
+        grad[mask_mid] = (2.0 * x[mask_mid] + 3.0) / 6.0
 
-
-
+        mask_high = x >= 3
+        grad[mask_high] = 1.0
 
-
-
+        positive_mask_1 = (x > 0).type_as(grad_output)
+        positive_mask_2 = (grad_output > 0).type_as(grad_output)
 
-
-
+        grad_input = grad_output * grad * positive_mask_1 * positive_mask_2
+
+        return grad_input
 
-        output_grad = input_img.grad.cpu().data.numpy()
-        output_grad = output_grad[0, :, :, :]
-        output_grad = output_grad.transpose((1, 2, 0))
 
-
-
-
-        _replace_all_layer_type_recursive(self.model, GuidedBackpropReLUAsModule, nn.ReLU())
+class GuidedReLU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GuidedBackpropReLU.apply(x)
 
-        return output_grad  # type: ignore[no-any-return]
 
+class GuidedSiLU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GuidedBackpropSiLU.apply(x)
 
-class GuidedBackpropInterpreter(Interpreter):
-    def interpret(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-        (input_tensor, rgb_img) = self._preprocess_image(image)
 
-
-
+class GuidedGELU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GuidedBackpropGELU.apply(x)
 
-
+
+class GuidedHardswish(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GuidedBackpropHardswish.apply(x)
+
+
+# Activation replacement mapping
+ACTIVATION_REPLACEMENTS: dict[type, type] = {
+    nn.ReLU: GuidedReLU,
+    nn.SiLU: GuidedSiLU,
+    nn.GELU: GuidedGELU,
+    nn.Hardswish: GuidedHardswish,
+}
+
+
+def replace_activations_recursive(model: nn.Module, replacements: dict[type, type]) -> None:
+    """
+    NOTE: This ONLY works for activations defined as nn.Module objects (e.g., self.act = nn.ReLU()).
+    It will NOT affect functional calls inside forward methods, such as F.relu(x) or F.gelu(x).
+    """
+
+    for name, module in list(model._modules.items()):  # pylint: disable=protected-access
+        for old_type, new_type in replacements.items():
+            if isinstance(module, old_type):
+                model._modules[name] = new_type()  # pylint: disable=protected-access
+                break
+        else:
+            # Recurse into submodules
+            replace_activations_recursive(module, replacements)
+
+
+def restore_activations_recursive(model: nn.Module, guided_types: dict[type, type]) -> None:
+    reverse_mapping = {v: k for k, v in guided_types.items()}
+    for name, module in list(model._modules.items()):  # pylint: disable=protected-access
+        for guided_type, original_type in reverse_mapping.items():
+            if isinstance(module, guided_type):
+                model._modules[name] = original_type()  # pylint: disable=protected-access
+                break
+        else:
+            restore_activations_recursive(module, guided_types)
+
+
+class GuidedBackprop:
+    def __init__(self, net: nn.Module, device: torch.device, transform: Callable[..., torch.Tensor]) -> None:
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
+
+    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+
+        # Get prediction
+        with torch.inference_mode():
+            logits = self.net(input_tensor)
+
+        if target_class is None:
+            target_class = predict_class(logits)
+        else:
+            validate_target_class(target_class, logits.shape[-1])
+
+        # Replace activations with guided versions
+        replace_activations_recursive(self.net, ACTIVATION_REPLACEMENTS)
+
+        try:
+            input_tensor = input_tensor.detach().requires_grad_(True)
+            output = self.net(input_tensor)
+
+            loss = output[0, target_class]
+            loss.backward(retain_graph=False)
+
+            gradients = input_tensor.grad.cpu().numpy()
+            gradients = gradients[0, :, :, :]  # Remove batch dim
+            gradients = gradients.transpose((1, 2, 0))  # CHW -> HWC
+
+        finally:
+            restore_activations_recursive(self.net, ACTIVATION_REPLACEMENTS)
+
+        visualization = deprocess_image(gradients * rgb_img)
+
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=gradients,
+            logits=logits.detach(),
+            predicted_class=target_class,
+        )
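The new GuidedBackpropGELU.backward applies the exact GELU derivative (the Gaussian CDF plus x times the Gaussian PDF) before masking negative activations and negative gradients. A quick standalone check, not part of the package, that this closed form matches autograd's gradient of F.gelu:

import math

import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=101, requires_grad=True)
F.gelu(x, approximate="none").sum().backward()  # autograd d/dx gelu(x), accumulated into x.grad

cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))          # Gaussian CDF at x
pdf = torch.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)   # Gaussian PDF at x
analytic = cdf + x * pdf

print(torch.allclose(x.grad, analytic.detach(), atol=1e-5))  # True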
birder/introspection/transformer_attribution.py
ADDED

@@ -0,0 +1,182 @@
+"""
+Transformer Attribution (Gradient-weighted Attention Rollout), adapted from
+https://github.com/hila-chefer/Transformer-Explainability
+
+Paper "Transformer Interpretability Beyond Attention Visualization", https://arxiv.org/abs/2012.09838
+"""
+
+# Reference license: MIT
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from torch import nn
+
+from birder.introspection.base import InterpretabilityResult
+from birder.introspection.base import predict_class
+from birder.introspection.base import preprocess_image
+from birder.introspection.base import show_mask_on_image
+from birder.introspection.base import validate_target_class
+from birder.net.vit import Encoder
+
+
+def compute_attribution_rollout(
+    attributions: list[tuple[torch.Tensor, torch.Tensor]], num_special_tokens: int, patch_grid_shape: tuple[int, int]
+) -> torch.Tensor:
+    """
+    NOTE: Uses gradient norm per token instead of element-wise grad * attention multiplication.
+    """
+
+    # Assume batch size = 1
+    num_tokens = attributions[0][0].size(-1)
+    device = attributions[0][0].device
+
+    result = torch.eye(num_tokens, device=device)
+    with torch.no_grad():
+        for attn_weights, output_grad in attributions:
+            # Compute token importance from output gradient norm across embedding dimension
+            token_importance = output_grad.norm(dim=-1, keepdim=True)
+            token_importance = token_importance.transpose(-1, -2)
+
+            # Weight attention patterns by token importance
+            weighted_attn = attn_weights * token_importance.unsqueeze(1)
+
+            # Fuse attention heads and apply non-negativity constraint
+            relevance = weighted_attn.mean(dim=1).clamp(min=0)
+
+            # Add residual connection and normalize
+            eye = torch.eye(num_tokens, device=device)
+            normalized = (relevance + eye) / 2.0
+            normalized = normalized / normalized.sum(dim=-1, keepdim=True)
+
+            # Accumulate attention across layers
+            result = torch.matmul(normalized, result)
+
+    rollout = result[0]
+
+    if 0 < num_special_tokens:
+        source_to_patches = rollout[:num_special_tokens, num_special_tokens:]
+        mask = source_to_patches.mean(dim=0)
+    else:
+        mask = rollout.mean(dim=0)
+
+    mask = mask / (mask.max() + 1e-8)
+
+    (grid_h, grid_w) = patch_grid_shape
+    mask = mask.reshape(grid_h, grid_w)
+
+    return mask
+
+
+class AttributionGatherer:
+    def __init__(self, net: nn.Module, attention_layer_name: str) -> None:
+        assert hasattr(net, "encoder") is True and isinstance(net.encoder, Encoder)
+
+        net.encoder.set_need_attn()
+
+        self.net = net
+        self.handles: list[torch.utils.hooks.RemovableHandle] = []
+        self._gradients: list[torch.Tensor] = []
+        self._attention_weights: list[torch.Tensor] = []
+
+        for name, module in self.net.named_modules():
+            if name.endswith(attention_layer_name) is True:
+                handle = module.register_forward_hook(self._capture_forward)
+                self.handles.append(handle)
+
+    def _capture_forward(
+        self, _module: nn.Module, _inputs: tuple[torch.Tensor, ...], output: tuple[torch.Tensor, ...] | torch.Tensor
+    ) -> None:
+        output_tensor = output[0]
+        attn_weights = output[1]
+
+        self._attention_weights.append(attn_weights.detach())
+        if output_tensor.requires_grad:
+
+            def _store_grad(grad: torch.Tensor) -> None:
+                self._gradients.append(grad.detach())
+
+            output_tensor.register_hook(_store_grad)
+
+    def get_captured_data(self) -> list[tuple[torch.Tensor, torch.Tensor]]:
+        if len(self._attention_weights) != len(self._gradients):
+            raise RuntimeError(
+                f"Mismatch between attention weights ({len(self._attention_weights)}) "
+                f"and gradients ({len(self._gradients)}). Ensure backward() was called."
+            )
+
+        if len(self._attention_weights) == 0:
+            raise RuntimeError("No attention data captured. Ensure the model has attention layers.")
+
+        # Pair attention weights with output gradients (gradients reversed to match forward order)
+        results = [(attn.cpu(), grad.cpu()) for attn, grad in zip(self._attention_weights, reversed(self._gradients))]
+
+        # Clear storage for next forward pass
+        self._gradients = []
+        self._attention_weights = []
+
+        return results
+
+    def release(self) -> None:
+        for handle in self.handles:
+            handle.remove()
+
+
+class TransformerAttribution:
+    def __init__(
+        self,
+        net: nn.Module,
+        device: torch.device,
+        transform: Callable[..., torch.Tensor],
+        attention_layer_name: str = "self_attention",
+    ) -> None:
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
+        self.gatherer = AttributionGatherer(net, attention_layer_name)
+
+    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+        input_tensor.requires_grad_(True)
+
+        self.net.zero_grad()
+        logits = self.net(input_tensor)
+
+        if target_class is None:
+            target_class = predict_class(logits)
+        else:
+            validate_target_class(target_class, logits.shape[-1])
+
+        score = logits[0, target_class]
+        score.backward()
+
+        attribution_data = self.gatherer.get_captured_data()
+
+        (_, _, H, W) = input_tensor.shape
+        patch_grid_shape = (H // self.net.stem_stride, W // self.net.stem_stride)
+
+        attribution_map = compute_attribution_rollout(
+            attribution_data, num_special_tokens=self.net.num_special_tokens, patch_grid_shape=patch_grid_shape
+        )
+
+        attribution_img = Image.fromarray(attribution_map.numpy())
+        attribution_img = attribution_img.resize((rgb_img.shape[1], rgb_img.shape[0]))
+        attribution_arr = np.array(attribution_img)
+
+        visualization = show_mask_on_image(rgb_img, attribution_arr, image_weight=0.4)
+
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=attribution_arr,
+            logits=logits.detach(),
+            predicted_class=target_class,
+        )
+
+    def __del__(self) -> None:
+        if hasattr(self, "gatherer") is True:
+            self.gatherer.release()
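compute_attribution_rollout expects one (attention weights, output gradient) pair per transformer block, with batch size 1. A synthetic sketch using random tensors (the layer count, head count, embedding width, and patch grid below are assumed for illustration, not taken from a real birder ViT) showing the expected input and output shapes:

import torch

from birder.introspection.transformer_attribution import compute_attribution_rollout

(num_layers, num_heads, embed_dim) = (4, 6, 192)
num_special_tokens = 1                                 # e.g. a single class token
patch_grid = (14, 14)
num_tokens = num_special_tokens + patch_grid[0] * patch_grid[1]

attributions = [
    (
        torch.rand(1, num_heads, num_tokens, num_tokens).softmax(dim=-1),  # attention weights per layer
        torch.randn(1, num_tokens, embed_dim),                             # gradient of that layer's output
    )
    for _ in range(num_layers)
]

mask = compute_attribution_rollout(attributions, num_special_tokens, patch_grid)
print(mask.shape)  # torch.Size([14, 14]); values are non-negative and scaled so the max is ~1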