birder 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/__init__.py +13 -0
- birder/adversarial/base.py +101 -0
- birder/adversarial/deepfool.py +173 -0
- birder/adversarial/fgsm.py +51 -18
- birder/adversarial/pgd.py +79 -28
- birder/adversarial/simba.py +172 -0
- birder/common/training_cli.py +11 -3
- birder/common/training_utils.py +18 -1
- birder/inference/data_parallel.py +1 -2
- birder/introspection/__init__.py +10 -6
- birder/introspection/attention_rollout.py +122 -54
- birder/introspection/base.py +73 -29
- birder/introspection/gradcam.py +71 -100
- birder/introspection/guided_backprop.py +146 -72
- birder/introspection/transformer_attribution.py +182 -0
- birder/net/detection/deformable_detr.py +14 -12
- birder/net/detection/detr.py +7 -3
- birder/net/detection/rt_detr_v1.py +3 -3
- birder/net/detection/yolo_v3.py +6 -11
- birder/net/detection/yolo_v4.py +7 -18
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/fastvit.py +1 -1
- birder/net/mim/mae_vit.py +7 -8
- birder/net/pit.py +1 -1
- birder/net/resnet_v1.py +94 -34
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +4 -2
- birder/results/gui.py +15 -2
- birder/scripts/predict_detection.py +33 -1
- birder/scripts/train.py +24 -17
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +12 -9
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +42 -18
- birder/scripts/train_dino_v1.py +10 -7
- birder/scripts/train_dino_v2.py +10 -7
- birder/scripts/train_dino_v2_dist.py +17 -7
- birder/scripts/train_franca.py +10 -7
- birder/scripts/train_i_jepa.py +17 -13
- birder/scripts/train_ibot.py +10 -7
- birder/scripts/train_kd.py +24 -18
- birder/scripts/train_mim.py +11 -10
- birder/scripts/train_mmcr.py +10 -7
- birder/scripts/train_rotnet.py +10 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/__main__.py +6 -2
- birder/tools/adversarial.py +147 -96
- birder/tools/auto_anchors.py +361 -0
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +58 -31
- birder/version.py +1 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/common/training_cli.py
CHANGED
@@ -110,10 +110,13 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
         type=int,
         default=40,
         metavar="N",
-        help="decrease lr every
+        help="decrease lr every N epochs/steps (relative to after warmup, step scheduler only)",
     )
     group.add_argument(
-        "--lr-steps",
+        "--lr-steps",
+        type=int,
+        nargs="+",
+        help="absolute epoch/step milestones when to decrease lr (multistep scheduler only)",
     )
     group.add_argument(
         "--lr-step-gamma",
@@ -391,7 +394,7 @@ def add_ema_args(
         "--model-ema-warmup",
         type=int,
         metavar="N",
-        help="number of epochs before EMA is applied (defaults to warmup epochs/
+        help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
     )
 
 
@@ -656,6 +659,11 @@ def common_args_validation(args: argparse.Namespace) -> None:
             f"but it is set to '{args.lr_scheduler_update}'"
         )
 
+    # EMA
+    if hasattr(args, "model_ema_steps") is True:
+        if args.model_ema_steps < 1:
+            raise ValidationError("--model-ema-steps must be >= 1")
+
     # Compile args, argument dependant
     if hasattr(args, "compile_teacher") is True:
        if args.compile is True and args.compile_teacher is True:
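
For readers skimming the CLI change: a minimal, self-contained sketch (not code from the package; the parser and group names are illustrative, and the type of --lr-step-gamma is an assumption) showing how the new --lr-steps definition above parses:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("lr scheduler")  # illustrative group name
group.add_argument(
    "--lr-steps",
    type=int,
    nargs="+",
    help="absolute epoch/step milestones when to decrease lr (multistep scheduler only)",
)
group.add_argument("--lr-step-gamma", type=float)  # float type assumed here

args = parser.parse_args(["--lr-steps", "30", "60", "90", "--lr-step-gamma", "0.1"])
assert args.lr_steps == [30, 60, 90]  # nargs="+" collects the milestones into a list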
birder/common/training_utils.py
CHANGED
@@ -491,12 +491,29 @@ def get_scheduler(
     if args.lr_scheduler == "constant":
         main_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
     elif args.lr_scheduler == "step":
+        # Note: StepLR step_size is relative to when the main scheduler starts (after warmup)
+        # This means drops occur relative to the end of warmup, not at absolute epoch numbers
         main_scheduler = torch.optim.lr_scheduler.StepLR(
             optimizer, step_size=args.lr_step_size, gamma=args.lr_step_gamma
         )
     elif args.lr_scheduler == "multistep":
+        # For MultiStepLR, milestones should be absolute step numbers
+        # Adjust them to be relative to when the main scheduler starts (after warmup)
+        # This ensures drops occur at the specified absolute steps, not relative to after warmup
+        adjusted_milestones = [m - warmup_steps for m in args.lr_steps if m >= warmup_steps]
+        if len(adjusted_milestones) == 0:
+            logger.debug(
+                f"All MultiStepLR milestones {args.lr_steps} are before warmup "
+                f"(warmup ends at step {warmup_steps}). Using empty milestone list."
+            )
+            adjusted_milestones = []
+
+        logger.debug(
+            f"MultiStepLR milestones adjusted from {args.lr_steps} to {adjusted_milestones} "
+            f"(relative to main scheduler start after {warmup_steps} warmup steps)"
+        )
         main_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer, milestones=
+            optimizer, milestones=adjusted_milestones, gamma=args.lr_step_gamma
         )
     elif args.lr_scheduler == "cosine":
         main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
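
The milestone adjustment above is easiest to see with concrete numbers; a small sketch with a toy optimizer and made-up values (warmup_steps plays the same role as in get_scheduler):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1.0)

warmup_steps = 5
lr_steps = [3, 10, 20]  # absolute milestones as passed via --lr-steps

# Same adjustment as in get_scheduler: milestones that fall inside warmup are dropped,
# the rest are shifted so they are relative to where the main scheduler starts.
adjusted_milestones = [m - warmup_steps for m in lr_steps if m >= warmup_steps]
assert adjusted_milestones == [5, 15]

main_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=adjusted_milestones, gamma=0.1
)

When this scheduler is chained after the 5 warmup steps, the learning rate still drops at the absolute steps 10 and 20 that were requested on the command line.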
birder/inference/data_parallel.py
CHANGED
@@ -1,8 +1,7 @@
 """
 Inference-optimized multi-GPU parallelization
 
-This module provides InferenceDataParallel, an inference-specific alternative to
-torch.nn.DataParallel.
+This module provides InferenceDataParallel, an inference-specific alternative to torch.nn.DataParallel.
 """
 
 import copy
birder/introspection/__init__.py
CHANGED
@@ -1,9 +1,13 @@
-from birder.introspection.attention_rollout import
-from birder.introspection.
-from birder.introspection.
+from birder.introspection.attention_rollout import AttentionRollout
+from birder.introspection.base import InterpretabilityResult
+from birder.introspection.gradcam import GradCAM
+from birder.introspection.guided_backprop import GuidedBackprop
+from birder.introspection.transformer_attribution import TransformerAttribution
 
 __all__ = [
-    "
-    "
-    "
+    "InterpretabilityResult",
+    "AttentionRollout",
+    "GradCAM",
+    "GuidedBackprop",
+    "TransformerAttribution",
 ]
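
With the rewritten __init__.py the interpreters become importable directly from the subpackage; a short sketch of the resulting public surface (construction and use of the classes themselves is shown after the attention_rollout.py diff below):

from birder.introspection import AttentionRollout
from birder.introspection import GradCAM
from birder.introspection import GuidedBackprop
from birder.introspection import InterpretabilityResult
from birder.introspection import TransformerAttribution

# The Interpreter protocol added in birder/introspection/base.py describes the
# common call shape:
#     result: InterpretabilityResult = interpreter(image, target_class=None)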
birder/introspection/attention_rollout.py
CHANGED
@@ -1,5 +1,8 @@
 """
-
+Attention Rollout for Vision Transformers, adapted from
+https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
+
+Paper "Quantifying Attention Flow in Transformers", https://arxiv.org/abs/2005.00928
 """
 
 # Reference license: MIT
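
For orientation, the recurrence that compute_rollout implements in the hunk below (notation ours, following the paper linked above): with A^(l) the head-fused attention of layer l and I the identity,

\[
\hat{A}^{(l)} = \operatorname{row\_normalize}\!\left(\tfrac{1}{2}\bigl(A^{(l)} + I\bigr)\right),
\qquad
R = \hat{A}^{(L)}\,\hat{A}^{(L-1)}\cdots\hat{A}^{(1)},
\]

and the spatial mask is read off the special-token rows of R, max-normalized, and reshaped to the patch grid.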
@@ -15,103 +18,168 @@ from PIL import Image
 from torch import nn
 
 from birder.introspection.base import InterpretabilityResult
-from birder.introspection.base import
+from birder.introspection.base import predict_class
+from birder.introspection.base import preprocess_image
 from birder.introspection.base import show_mask_on_image
 from birder.net.vit import Encoder
 
 
-
+# pylint: disable=too-many-locals
+def compute_rollout(
     attentions: list[torch.Tensor],
     discard_ratio: float,
     head_fusion: Literal["mean", "max", "min"],
     num_special_tokens: int,
+    patch_grid_shape: tuple[int, int],
 ) -> torch.Tensor:
-
+    # Assume batch size = 1
+    num_tokens = attentions[0].size(-1)
+    device = attentions[0].device
+
+    # Start with identity (residual)
+    result = torch.eye(num_tokens, device=device)
+
     with torch.no_grad():
         for attention in attentions:
+            # Fuse heads: [B, H, T, T] -> [B, T, T]
             if head_fusion == "mean":
-                attention_heads_fused = attention.mean(
+                attention_heads_fused = attention.mean(dim=1)
             elif head_fusion == "max":
-                attention_heads_fused = attention.max(
+                attention_heads_fused = attention.max(dim=1)[0]
             elif head_fusion == "min":
-                attention_heads_fused = attention.min(
+                attention_heads_fused = attention.min(dim=1)[0]
             else:
-                raise ValueError("
-
-                #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                raise ValueError(f"Unsupported head_fusion: {head_fusion}")
+
+            # attention_heads_fused: [1, T, T] (batch = 1)
+            if discard_ratio > 0:
+                # Work on the single batch element
+                attn = attention_heads_fused[0]  # [T, T]
+
+                # Define which positions are "non-special"
+                idx = torch.arange(num_tokens, device=attn.device)
+                is_special = idx < num_special_tokens
+                non_special = ~is_special
+
+                # We are only allowed to prune NON-special <-> NON-special entries
+                allow = non_special[:, None] & non_special[None, :]  # [T, T]
+
+                allowed_values = attn[allow]
+                num_allowed = allowed_values.numel()
+                if num_allowed > 0:
+                    num_to_discard = int(num_allowed * discard_ratio)
+                    if num_to_discard > 0:
+                        # Drop the smallest allowed values
+                        (_, low_idx) = torch.topk(allowed_values, num_to_discard, largest=False)
+                        allowed_values[low_idx] = 0
+                        attn[allow] = allowed_values
+                        attention_heads_fused[0] = attn
+
+            # Add residual connection and normalize
+            eye = torch.eye(num_tokens, device=attention_heads_fused.device)
+            a = (attention_heads_fused + eye) / 2.0  # [1, T, T]
+            a = a / a.sum(dim=-1, keepdim=True)
+
+            # Accumulate attention across layers
+            result = torch.matmul(a, result)  # [1, T, T]
+
+    rollout = result[0]  # [T, T]
+
+    # Build final token → patch map
+    if 0 < num_special_tokens < num_tokens:
+        # Sources: all special tokens (0 .. num_special_tokens-1)
+        # Targets: all non-special tokens (num_special_tokens .. end)
+        source_to_patches = rollout[:num_special_tokens, num_special_tokens:]
+        mask = source_to_patches.mean(dim=0)
+    else:
+        # No special tokens (or all are special): fall back to averaging over all sources
+        mask = rollout.mean(dim=0)  # [T]
+
+    # Normalize and reshape to 2D map using actual patch grid dimensions
+    mask = mask / (mask.max() + 1e-8)
+    (grid_h, grid_w) = patch_grid_shape
+    mask = mask.reshape(grid_h, grid_w)
 
     return mask
 
 
-class
-    def __init__(self, net:
+class AttentionGatherer:
+    def __init__(self, net: nn.Module, attention_layer_name: str) -> None:
         assert hasattr(net, "encoder") is True and isinstance(net.encoder, Encoder)
+
         net.encoder.set_need_attn()
         self.net = net
+        self.attentions: list[torch.Tensor] = []
+        self.handles: list[torch.utils.hooks.RemovableHandle] = []
+
+        # Register hooks on attention layers
         for name, module in self.net.named_modules():
             if name.endswith(attention_layer_name) is True:
-                module.register_forward_hook(self.
+                handle = module.register_forward_hook(self._capture_attention)
+                self.handles.append(handle)
 
-
-
-    def get_attention(
-        self, _module: torch.nn.Module, _inputs: tuple[torch.Tensor, ...], outputs: tuple[torch.Tensor, ...]
+    def _capture_attention(
+        self, _module: nn.Module, _inputs: tuple[torch.Tensor, ...], outputs: tuple[torch.Tensor, ...]
     ) -> None:
         self.attentions.append(outputs[1].cpu())
 
-    def __call__(
-        self, x: torch.Tensor, discard_ratio: float, head_fusion: Literal["mean", "max", "min"]
-    ) -> torch.Tensor:
+    def __call__(self, x: torch.Tensor) -> tuple[list[torch.Tensor], torch.Tensor]:
         self.attentions = []
         with torch.inference_mode():
-            self.net(x)
+            logits = self.net(x)
+
+        return (self.attentions, logits)
 
-
+    def release(self) -> None:
+        for handle in self.handles:
+            handle.remove()
 
 
-class
+class AttentionRollout:
     def __init__(
         self,
-
+        net: nn.Module,
         device: torch.device,
         transform: Callable[..., torch.Tensor],
-        attention_layer_name: str,
-        discard_ratio: float,
-        head_fusion: Literal["mean", "max", "min"],
+        attention_layer_name: str = "self_attention",
+        discard_ratio: float = 0.9,
+        head_fusion: Literal["mean", "max", "min"] = "max",
     ) -> None:
-
-
+        if not 0 <= discard_ratio <= 1:
+            raise ValueError(f"discard_ratio must be in [0, 1], got {discard_ratio}")
+
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
         self.discard_ratio = discard_ratio
         self.head_fusion = head_fusion
+        self.attention_gatherer = AttentionGatherer(net, attention_layer_name)
 
-    def
-        (input_tensor, rgb_img) = self.
+    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
 
-
-            input_tensor, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion
-        )
+        (attentions, logits) = self.attention_gatherer(input_tensor)
 
-
+        (_, _, H, W) = input_tensor.shape
+        patch_grid_shape = (H // self.net.stem_stride, W // self.net.stem_stride)
+
+        attention_map = compute_rollout(
+            attentions, self.discard_ratio, self.head_fusion, self.net.num_special_tokens, patch_grid_shape
+        )
         attention_img = Image.fromarray(attention_map.numpy())
-        attention_img = attention_img.resize(rgb_img.shape[
+        attention_img = attention_img.resize((rgb_img.shape[1], rgb_img.shape[0]))
         attention_arr = np.array(attention_img)
+
         visualization = show_mask_on_image(rgb_img, attention_arr, image_weight=0.4)
 
-        return InterpretabilityResult(
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=attention_arr,
+            logits=logits.detach(),
+            predicted_class=predict_class(logits),
+        )
+
+    def __del__(self) -> None:
+        if hasattr(self, "attention_gatherer") is True:
+            self.attention_gatherer.release()
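
A hedged usage sketch of the reworked AttentionRollout; the helper function below is illustrative and not part of the package, and it assumes `net` is a Birder ViT-style classifier exposing encoder, stem_stride and num_special_tokens (the attributes the class uses above), with `transform` matching the model's inference preprocessing:

from collections.abc import Callable

import torch

from birder.introspection import AttentionRollout

def explain_with_rollout(net: torch.nn.Module, transform: Callable[..., torch.Tensor], image_path: str) -> None:
    # Illustrative helper, not from the package
    device = torch.device("cpu")
    rollout = AttentionRollout(
        net,
        device,
        transform,
        discard_ratio=0.9,  # prune the 90% weakest non-special attention entries per layer
        head_fusion="max",  # how attention heads are fused before rolling out
    )
    result = rollout(image_path)  # InterpretabilityResult
    print(result.predicted_class)
    result.show()  # overlay next to the original image

Since the constructor now carries defaults (attention_layer_name="self_attention", discard_ratio=0.9, head_fusion="max"), callers only need to pass the network, device and inference transform.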
birder/introspection/base.py
CHANGED
@@ -2,6 +2,7 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
+from typing import Protocol
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -9,13 +10,56 @@ import numpy as np
 import numpy.typing as npt
 import torch
 from PIL import Image
-
+
+
+@dataclass(frozen=True)
+class InterpretabilityResult:
+    original_image: npt.NDArray[np.float32]
+    visualization: npt.NDArray[np.float32] | npt.NDArray[np.uint8]
+    raw_output: npt.NDArray[np.float32]
+    logits: Optional[torch.Tensor] = None
+    predicted_class: Optional[int] = None
+
+    def show(self, figsize: tuple[int, int] = (12, 8)) -> None:
+        _, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+        ax1.imshow(self.visualization)
+        ax2.imshow(self.original_image)
+        plt.show()
+
+
+class Interpreter(Protocol):
+    def __call__(
+        self, image: str | Path | Image.Image, target_class: Optional[int] = None
+    ) -> InterpretabilityResult: ...
+
+
+def load_image(image: str | Path | Image.Image) -> Image.Image:
+    if isinstance(image, (str, Path)):
+        return Image.open(image)
+
+    return image
+
+
+def preprocess_image(
+    image: str | Path | Image.Image, transform: Callable[..., torch.Tensor], device: torch.device
+) -> tuple[torch.Tensor, npt.NDArray[np.float32]]:
+    pil_image = load_image(image)
+    input_tensor = transform(pil_image).unsqueeze(dim=0).to(device)
+
+    # Resize and normalize for visualization
+    resized = pil_image.resize((input_tensor.shape[-1], input_tensor.shape[-2]))
+    rgb_img = np.array(resized).astype(np.float32) / 255.0
+
+    return (input_tensor, rgb_img)
 
 
 def show_mask_on_image(
-    img: npt.NDArray[np.float32],
-
-
+    img: npt.NDArray[np.float32],
+    mask: npt.NDArray[np.float32],
+    image_weight: float = 0.5,
+    colormap: str = "jet",
+) -> npt.NDArray[np.uint8]:
+    color_map = matplotlib.colormaps[colormap]
     heatmap = color_map(mask)[:, :, :3]
 
     cam: npt.NDArray[np.float32] = (1 - image_weight) * heatmap + image_weight * img
@@ -25,36 +69,36 @@ def show_mask_on_image(
     return cam.astype(np.uint8)
 
 
-
-
-
-
-
+def scale_cam_image(
+    cam: npt.NDArray[np.float32], target_size: Optional[tuple[int, int]] = None
+) -> npt.NDArray[np.float32]:
+    result = []
+    for img in cam:
+        img = img - np.min(img)
+        img = img / (1e-7 + np.max(img))
+        if target_size is not None:
+            img = np.array(Image.fromarray(img).resize(target_size))
 
-
-
-
-        ax2.imshow(self.original_image)
-        plt.show()
+        result.append(img)
+
+    return np.array(result, dtype=np.float32)
 
 
-
-
-
-
-
+def deprocess_image(img: npt.NDArray[np.float32]) -> npt.NDArray[np.uint8]:
+    img = img - np.mean(img)
+    img = img / (np.std(img) + 1e-5)
+    img = img * 0.1
+    img = img + 0.5
+    img = np.clip(img, 0, 1)
 
-
-        raise NotImplementedError
+    return np.array(img * 255).astype(np.uint8)
 
-    def _preprocess_image(self, image: str | Path | Image.Image) -> tuple[torch.Tensor, npt.NDArray[np.float32]]:
-        if isinstance(image, (str, Path)):
-            image = Image.open(image)
 
-
-
+def validate_target_class(target_class: Optional[int], num_classes: int) -> None:
+    if target_class is not None:
+        if target_class < 0 or target_class >= num_classes:
+            raise ValueError(f"target_class must be in range [0, {num_classes}), got {target_class}")
 
-        # Store original for visualization
-        rgb_img = np.array(image.resize(input_tensor.shape[-2:])).astype(np.float32) / 255.0
 
-
+def predict_class(logits: torch.Tensor) -> int:
+    return int(torch.argmax(logits, dim=-1).item())