birder 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. birder/adversarial/__init__.py +13 -0
  2. birder/adversarial/base.py +101 -0
  3. birder/adversarial/deepfool.py +173 -0
  4. birder/adversarial/fgsm.py +51 -18
  5. birder/adversarial/pgd.py +79 -28
  6. birder/adversarial/simba.py +172 -0
  7. birder/common/training_cli.py +11 -3
  8. birder/common/training_utils.py +18 -1
  9. birder/inference/data_parallel.py +1 -2
  10. birder/introspection/__init__.py +10 -6
  11. birder/introspection/attention_rollout.py +122 -54
  12. birder/introspection/base.py +73 -29
  13. birder/introspection/gradcam.py +71 -100
  14. birder/introspection/guided_backprop.py +146 -72
  15. birder/introspection/transformer_attribution.py +182 -0
  16. birder/net/detection/deformable_detr.py +14 -12
  17. birder/net/detection/detr.py +7 -3
  18. birder/net/detection/rt_detr_v1.py +3 -3
  19. birder/net/detection/yolo_v3.py +6 -11
  20. birder/net/detection/yolo_v4.py +7 -18
  21. birder/net/detection/yolo_v4_tiny.py +3 -3
  22. birder/net/fastvit.py +1 -1
  23. birder/net/mim/mae_vit.py +7 -8
  24. birder/net/pit.py +1 -1
  25. birder/net/resnet_v1.py +94 -34
  26. birder/net/ssl/data2vec.py +1 -1
  27. birder/net/ssl/data2vec2.py +4 -2
  28. birder/results/gui.py +15 -2
  29. birder/scripts/predict_detection.py +33 -1
  30. birder/scripts/train.py +24 -17
  31. birder/scripts/train_barlow_twins.py +10 -7
  32. birder/scripts/train_byol.py +10 -7
  33. birder/scripts/train_capi.py +12 -9
  34. birder/scripts/train_data2vec.py +10 -7
  35. birder/scripts/train_data2vec2.py +10 -7
  36. birder/scripts/train_detection.py +42 -18
  37. birder/scripts/train_dino_v1.py +10 -7
  38. birder/scripts/train_dino_v2.py +10 -7
  39. birder/scripts/train_dino_v2_dist.py +17 -7
  40. birder/scripts/train_franca.py +10 -7
  41. birder/scripts/train_i_jepa.py +17 -13
  42. birder/scripts/train_ibot.py +10 -7
  43. birder/scripts/train_kd.py +24 -18
  44. birder/scripts/train_mim.py +11 -10
  45. birder/scripts/train_mmcr.py +10 -7
  46. birder/scripts/train_rotnet.py +10 -7
  47. birder/scripts/train_simclr.py +10 -7
  48. birder/scripts/train_vicreg.py +10 -7
  49. birder/tools/__main__.py +6 -2
  50. birder/tools/adversarial.py +147 -96
  51. birder/tools/auto_anchors.py +361 -0
  52. birder/tools/ensemble_model.py +1 -1
  53. birder/tools/introspection.py +58 -31
  54. birder/version.py +1 -1
  55. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
  56. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
  57. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
  58. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
  59. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
  60. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/common/training_cli.py

@@ -110,10 +110,13 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
         type=int,
         default=40,
         metavar="N",
-        help="decrease lr every step-size epochs (for step scheduler only)",
+        help="decrease lr every N epochs/steps (relative to after warmup, step scheduler only)",
     )
     group.add_argument(
-        "--lr-steps", type=int, nargs="+", help="decrease lr every step-size epochs (multistep scheduler only)"
+        "--lr-steps",
+        type=int,
+        nargs="+",
+        help="absolute epoch/step milestones when to decrease lr (multistep scheduler only)",
     )
     group.add_argument(
         "--lr-step-gamma",

@@ -391,7 +394,7 @@ def add_ema_args(
         "--model-ema-warmup",
         type=int,
         metavar="N",
-        help="number of epochs before EMA is applied (defaults to warmup epochs/iters, pass 0 to disable warmup)",
+        help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
     )


@@ -656,6 +659,11 @@ def common_args_validation(args: argparse.Namespace) -> None:
             f"but it is set to '{args.lr_scheduler_update}'"
         )

+    # EMA
+    if hasattr(args, "model_ema_steps") is True:
+        if args.model_ema_steps < 1:
+            raise ValidationError("--model-ema-steps must be >= 1")
+
     # Compile args, argument dependant
     if hasattr(args, "compile_teacher") is True:
         if args.compile is True and args.compile_teacher is True:
birder/common/training_utils.py

@@ -491,12 +491,29 @@ def get_scheduler(
     if args.lr_scheduler == "constant":
         main_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
     elif args.lr_scheduler == "step":
+        # Note: StepLR step_size is relative to when the main scheduler starts (after warmup)
+        # This means drops occur relative to the end of warmup, not at absolute epoch numbers
        main_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_step_size, gamma=args.lr_step_gamma
        )
    elif args.lr_scheduler == "multistep":
+        # For MultiStepLR, milestones should be absolute step numbers
+        # Adjust them to be relative to when the main scheduler starts (after warmup)
+        # This ensures drops occur at the specified absolute steps, not relative to after warmup
+        adjusted_milestones = [m - warmup_steps for m in args.lr_steps if m >= warmup_steps]
+        if len(adjusted_milestones) == 0:
+            logger.debug(
+                f"All MultiStepLR milestones {args.lr_steps} are before warmup "
+                f"(warmup ends at step {warmup_steps}). Using empty milestone list."
+            )
+            adjusted_milestones = []
+
+        logger.debug(
+            f"MultiStepLR milestones adjusted from {args.lr_steps} to {adjusted_milestones} "
+            f"(relative to main scheduler start after {warmup_steps} warmup steps)"
+        )
        main_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer, milestones=args.lr_steps, gamma=args.lr_step_gamma
+            optimizer, milestones=adjusted_milestones, gamma=args.lr_step_gamma
        )
    elif args.lr_scheduler == "cosine":
        main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
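To make the semantic change concrete, here is a minimal sketch (outside of birder, with made-up numbers) of how absolute --lr-steps milestones end up relative once a warmup scheduler runs first. The optimizer, warmup length and milestone values below are placeholders, and chaining the two schedulers via SequentialLR is only an assumption about how warmup and the main scheduler are combined:

import torch

warmup_steps = 5                # hypothetical warmup length
lr_steps = [30, 60, 90]         # absolute milestones, as --lr-steps now expects

# Same adjustment as in get_scheduler above: shift milestones so that MultiStepLR,
# which starts counting only after warmup, still drops at the absolute steps
adjusted_milestones = [m - warmup_steps for m in lr_steps if m >= warmup_steps]
print(adjusted_milestones)      # [25, 55, 85]

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps)
main = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=adjusted_milestones, gamma=0.1)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps])
# The learning rate now drops at absolute steps 30, 60 and 90 instead of 35, 65 and 95.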
birder/inference/data_parallel.py

@@ -1,8 +1,7 @@
 """
 Inference-optimized multi-GPU parallelization

-This module provides InferenceDataParallel, an inference-specific alternative to
-torch.nn.DataParallel.
+This module provides InferenceDataParallel, an inference-specific alternative to torch.nn.DataParallel.
 """

 import copy
birder/introspection/__init__.py

@@ -1,9 +1,13 @@
-from birder.introspection.attention_rollout import AttentionRolloutInterpreter
-from birder.introspection.gradcam import GradCamInterpreter
-from birder.introspection.guided_backprop import GuidedBackpropInterpreter
+from birder.introspection.attention_rollout import AttentionRollout
+from birder.introspection.base import InterpretabilityResult
+from birder.introspection.gradcam import GradCAM
+from birder.introspection.guided_backprop import GuidedBackprop
+from birder.introspection.transformer_attribution import TransformerAttribution

 __all__ = [
-    "AttentionRolloutInterpreter",
-    "GradCamInterpreter",
-    "GuidedBackpropInterpreter",
+    "InterpretabilityResult",
+    "AttentionRollout",
+    "GradCAM",
+    "GuidedBackprop",
+    "TransformerAttribution",
 ]
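With the rename, downstream code imports the algorithm classes directly instead of the old *Interpreter names. A minimal usage sketch, assuming net, device and transform have already been set up by the caller and the image path is a placeholder; the constructor and call signatures follow the attention_rollout.py hunk further below:

from birder.introspection import AttentionRollout
from birder.introspection import InterpretabilityResult

# net: a ViT-style birder classification model, device: a torch.device,
# transform: the model's inference transform -- all assumed to exist already
rollout = AttentionRollout(net, device, transform, head_fusion="max", discard_ratio=0.9)
result: InterpretabilityResult = rollout("example.jpeg")
result.show()
print(result.predicted_class)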
birder/introspection/attention_rollout.py

@@ -1,5 +1,8 @@
 """
-Adapted from https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
+Attention Rollout for Vision Transformers, adapted from
+https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
+
+Paper "Quantifying Attention Flow in Transformers", https://arxiv.org/abs/2005.00928
 """

 # Reference license: MIT
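As a reminder of what the rollout below computes, the paper's recursion matches the residual-and-normalize step in compute_rollout (head fusion and the discard step are applied to each layer's attention matrix $A^{(l)}$ first):

$\tilde{A}^{(l)}_{ij} = \dfrac{\tfrac{1}{2}\big(A^{(l)}_{ij} + I_{ij}\big)}{\sum_k \tfrac{1}{2}\big(A^{(l)}_{ik} + I_{ik}\big)}, \qquad R = \tilde{A}^{(L)} \tilde{A}^{(L-1)} \cdots \tilde{A}^{(1)}$

The relevance map is then read off from the special-token rows of $R$, restricted to the patch columns, normalized by its maximum, and reshaped to the patch grid.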
@@ -15,103 +18,168 @@ from PIL import Image
 from torch import nn

 from birder.introspection.base import InterpretabilityResult
-from birder.introspection.base import Interpreter
+from birder.introspection.base import predict_class
+from birder.introspection.base import preprocess_image
 from birder.introspection.base import show_mask_on_image
 from birder.net.vit import Encoder


-def rollout(
+# pylint: disable=too-many-locals
+def compute_rollout(
     attentions: list[torch.Tensor],
     discard_ratio: float,
     head_fusion: Literal["mean", "max", "min"],
     num_special_tokens: int,
+    patch_grid_shape: tuple[int, int],
 ) -> torch.Tensor:
-    result = torch.eye(attentions[0].size(-1))
+    # Assume batch size = 1
+    num_tokens = attentions[0].size(-1)
+    device = attentions[0].device
+
+    # Start with identity (residual)
+    result = torch.eye(num_tokens, device=device)
+
     with torch.no_grad():
         for attention in attentions:
+            # Fuse heads: [B, H, T, T] -> [B, T, T]
             if head_fusion == "mean":
-                attention_heads_fused = attention.mean(axis=1)
+                attention_heads_fused = attention.mean(dim=1)
             elif head_fusion == "max":
-                attention_heads_fused = attention.max(axis=1)[0]
+                attention_heads_fused = attention.max(dim=1)[0]
             elif head_fusion == "min":
-                attention_heads_fused = attention.min(axis=1)[0]
+                attention_heads_fused = attention.min(dim=1)[0]
             else:
-                raise ValueError("Attention head fusion type Not supported")
-
-            # Drop the lowest attentions, but don't drop the class token
-            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
-            (_, indices) = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
-            indices = indices[indices != 0]
-            flat[0, indices] = 0
-
-            eye = torch.eye(attention_heads_fused.size(-1))
-            a = (attention_heads_fused + 1.0 * eye) / 2
-            a = a / a.sum(dim=-1)
-
-            result = torch.matmul(a, result)
-
-    # Look at the total attention between the class token and the image patches
-    mask = result[0, 0, num_special_tokens:]
-
-    width = int(mask.size(-1) ** 0.5)
-    mask = mask.reshape(width, width)
-    mask = mask / torch.max(mask)
+                raise ValueError(f"Unsupported head_fusion: {head_fusion}")
+
+            # attention_heads_fused: [1, T, T] (batch = 1)
+            if discard_ratio > 0:
+                # Work on the single batch element
+                attn = attention_heads_fused[0]  # [T, T]
+
+                # Define which positions are "non-special"
+                idx = torch.arange(num_tokens, device=attn.device)
+                is_special = idx < num_special_tokens
+                non_special = ~is_special
+
+                # We are only allowed to prune NON-special <-> NON-special entries
+                allow = non_special[:, None] & non_special[None, :]  # [T, T]
+
+                allowed_values = attn[allow]
+                num_allowed = allowed_values.numel()
+                if num_allowed > 0:
+                    num_to_discard = int(num_allowed * discard_ratio)
+                    if num_to_discard > 0:
+                        # Drop the smallest allowed values
+                        (_, low_idx) = torch.topk(allowed_values, num_to_discard, largest=False)
+                        allowed_values[low_idx] = 0
+                        attn[allow] = allowed_values
+                        attention_heads_fused[0] = attn
+
+            # Add residual connection and normalize
+            eye = torch.eye(num_tokens, device=attention_heads_fused.device)
+            a = (attention_heads_fused + eye) / 2.0  # [1, T, T]
+            a = a / a.sum(dim=-1, keepdim=True)
+
+            # Accumulate attention across layers
+            result = torch.matmul(a, result)  # [1, T, T]
+
+    rollout = result[0]  # [T, T]
+
+    # Build final token → patch map
+    if 0 < num_special_tokens < num_tokens:
+        # Sources: all special tokens (0 .. num_special_tokens-1)
+        # Targets: all non-special tokens (num_special_tokens .. end)
+        source_to_patches = rollout[:num_special_tokens, num_special_tokens:]
+        mask = source_to_patches.mean(dim=0)
+    else:
+        # No special tokens (or all are special): fall back to averaging over all sources
+        mask = rollout.mean(dim=0)  # [T]
+
+    # Normalize and reshape to 2D map using actual patch grid dimensions
+    mask = mask / (mask.max() + 1e-8)
+    (grid_h, grid_w) = patch_grid_shape
+    mask = mask.reshape(grid_h, grid_w)

     return mask


-class AttentionRollout:
-    def __init__(self, net: torch.nn.Module, attention_layer_name: str) -> None:
+class AttentionGatherer:
+    def __init__(self, net: nn.Module, attention_layer_name: str) -> None:
         assert hasattr(net, "encoder") is True and isinstance(net.encoder, Encoder)
+
         net.encoder.set_need_attn()
         self.net = net
+        self.attentions: list[torch.Tensor] = []
+        self.handles: list[torch.utils.hooks.RemovableHandle] = []
+
+        # Register hooks on attention layers
         for name, module in self.net.named_modules():
             if name.endswith(attention_layer_name) is True:
-                module.register_forward_hook(self.get_attention)
+                handle = module.register_forward_hook(self._capture_attention)
+                self.handles.append(handle)

-        self.attentions: list[torch.Tensor] = []
-
-    def get_attention(
-        self, _module: torch.nn.Module, _inputs: tuple[torch.Tensor, ...], outputs: tuple[torch.Tensor, ...]
+    def _capture_attention(
+        self, _module: nn.Module, _inputs: tuple[torch.Tensor, ...], outputs: tuple[torch.Tensor, ...]
     ) -> None:
         self.attentions.append(outputs[1].cpu())

-    def __call__(
-        self, x: torch.Tensor, discard_ratio: float, head_fusion: Literal["mean", "max", "min"]
-    ) -> torch.Tensor:
+    def __call__(self, x: torch.Tensor) -> tuple[list[torch.Tensor], torch.Tensor]:
         self.attentions = []
         with torch.inference_mode():
-            self.net(x)
+            logits = self.net(x)
+
+        return (self.attentions, logits)

-        return rollout(self.attentions, discard_ratio, head_fusion, self.net.num_special_tokens)
+    def release(self) -> None:
+        for handle in self.handles:
+            handle.remove()


-class AttentionRolloutInterpreter(Interpreter):
+class AttentionRollout:
     def __init__(
         self,
-        model: nn.Module,
+        net: nn.Module,
         device: torch.device,
         transform: Callable[..., torch.Tensor],
-        attention_layer_name: str,
-        discard_ratio: float,
-        head_fusion: Literal["mean", "max", "min"],
+        attention_layer_name: str = "self_attention",
+        discard_ratio: float = 0.9,
+        head_fusion: Literal["mean", "max", "min"] = "max",
     ) -> None:
-        super().__init__(model, device, transform)
-        self.attention_rollout = AttentionRollout(model, attention_layer_name)
+        if not 0 <= discard_ratio <= 1:
+            raise ValueError(f"discard_ratio must be in [0, 1], got {discard_ratio}")
+
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
         self.discard_ratio = discard_ratio
         self.head_fusion = head_fusion
+        self.attention_gatherer = AttentionGatherer(net, attention_layer_name)

-    def interpret(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-        (input_tensor, rgb_img) = self._preprocess_image(image)
+    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)

-        attention_map = self.attention_rollout(
-            input_tensor, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion
-        )
+        (attentions, logits) = self.attention_gatherer(input_tensor)

-        # Resize attention map to match image size
+        (_, _, H, W) = input_tensor.shape
+        patch_grid_shape = (H // self.net.stem_stride, W // self.net.stem_stride)
+
+        attention_map = compute_rollout(
+            attentions, self.discard_ratio, self.head_fusion, self.net.num_special_tokens, patch_grid_shape
+        )
         attention_img = Image.fromarray(attention_map.numpy())
-        attention_img = attention_img.resize(rgb_img.shape[:2])
+        attention_img = attention_img.resize((rgb_img.shape[1], rgb_img.shape[0]))
         attention_arr = np.array(attention_img)
+
         visualization = show_mask_on_image(rgb_img, attention_arr, image_weight=0.4)

-        return InterpretabilityResult(rgb_img, visualization, raw_output=attention_arr)
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=attention_arr,
+            logits=logits.detach(),
+            predicted_class=predict_class(logits),
+        )
+
+    def __del__(self) -> None:
+        if hasattr(self, "attention_gatherer") is True:
+            self.attention_gatherer.release()
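A small self-contained sketch of how the refactored helper behaves on toy inputs; the shapes are made up for illustration (2 layers, 4 heads, a single class token plus a 3x3 patch grid) and are not taken from any birder model:

import torch

from birder.introspection.attention_rollout import compute_rollout

# Toy attention maps: batch 1, 4 heads, 10 tokens (1 special token + 3x3 patches)
attentions = [torch.softmax(torch.randn(1, 4, 10, 10), dim=-1) for _ in range(2)]

mask = compute_rollout(
    attentions,
    discard_ratio=0.5,
    head_fusion="mean",
    num_special_tokens=1,
    patch_grid_shape=(3, 3),
)
print(mask.shape)  # torch.Size([3, 3]) -- one relevance value per image patch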
birder/introspection/base.py

@@ -2,6 +2,7 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
+from typing import Protocol

 import matplotlib
 import matplotlib.pyplot as plt

@@ -9,13 +10,56 @@ import numpy as np
 import numpy.typing as npt
 import torch
 from PIL import Image
-from torch import nn
+
+
+@dataclass(frozen=True)
+class InterpretabilityResult:
+    original_image: npt.NDArray[np.float32]
+    visualization: npt.NDArray[np.float32] | npt.NDArray[np.uint8]
+    raw_output: npt.NDArray[np.float32]
+    logits: Optional[torch.Tensor] = None
+    predicted_class: Optional[int] = None
+
+    def show(self, figsize: tuple[int, int] = (12, 8)) -> None:
+        _, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+        ax1.imshow(self.visualization)
+        ax2.imshow(self.original_image)
+        plt.show()
+
+
+class Interpreter(Protocol):
+    def __call__(
+        self, image: str | Path | Image.Image, target_class: Optional[int] = None
+    ) -> InterpretabilityResult: ...
+
+
+def load_image(image: str | Path | Image.Image) -> Image.Image:
+    if isinstance(image, (str, Path)):
+        return Image.open(image)
+
+    return image
+
+
+def preprocess_image(
+    image: str | Path | Image.Image, transform: Callable[..., torch.Tensor], device: torch.device
+) -> tuple[torch.Tensor, npt.NDArray[np.float32]]:
+    pil_image = load_image(image)
+    input_tensor = transform(pil_image).unsqueeze(dim=0).to(device)
+
+    # Resize and normalize for visualization
+    resized = pil_image.resize((input_tensor.shape[-1], input_tensor.shape[-2]))
+    rgb_img = np.array(resized).astype(np.float32) / 255.0
+
+    return (input_tensor, rgb_img)


 def show_mask_on_image(
-    img: npt.NDArray[np.float32], mask: npt.NDArray[np.float32], image_weight: float = 0.5
-) -> npt.NDArray[np.float32]:
-    color_map = matplotlib.colormaps["jet"]
+    img: npt.NDArray[np.float32],
+    mask: npt.NDArray[np.float32],
+    image_weight: float = 0.5,
+    colormap: str = "jet",
+) -> npt.NDArray[np.uint8]:
+    color_map = matplotlib.colormaps[colormap]
     heatmap = color_map(mask)[:, :, :3]

     cam: npt.NDArray[np.float32] = (1 - image_weight) * heatmap + image_weight * img

@@ -25,36 +69,36 @@ def show_mask_on_image(
     return cam.astype(np.uint8)


-@dataclass
-class InterpretabilityResult:
-    original_image: npt.NDArray[np.float32]
-    visualization: npt.NDArray[np.float32] | npt.NDArray[np.uint8]
-    raw_output: npt.NDArray[np.float32]
+def scale_cam_image(
+    cam: npt.NDArray[np.float32], target_size: Optional[tuple[int, int]] = None
+) -> npt.NDArray[np.float32]:
+    result = []
+    for img in cam:
+        img = img - np.min(img)
+        img = img / (1e-7 + np.max(img))
+        if target_size is not None:
+            img = np.array(Image.fromarray(img).resize(target_size))

-    def show(self, figsize: tuple[int, int] = (12, 8)) -> None:
-        _, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
-        ax1.imshow(self.visualization)
-        ax2.imshow(self.original_image)
-        plt.show()
+        result.append(img)
+
+    return np.array(result, dtype=np.float32)


-class Interpreter:
-    def __init__(self, model: nn.Module, device: torch.device, transform: Callable[..., torch.Tensor]) -> None:
-        self.model = model.eval()
-        self.device = device
-        self.transform = transform
+def deprocess_image(img: npt.NDArray[np.float32]) -> npt.NDArray[np.uint8]:
+    img = img - np.mean(img)
+    img = img / (np.std(img) + 1e-5)
+    img = img * 0.1
+    img = img + 0.5
+    img = np.clip(img, 0, 1)

-    def interpret(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-        raise NotImplementedError
+    return np.array(img * 255).astype(np.uint8)

-    def _preprocess_image(self, image: str | Path | Image.Image) -> tuple[torch.Tensor, npt.NDArray[np.float32]]:
-        if isinstance(image, (str, Path)):
-            image = Image.open(image)

-        # Transform for model
-        input_tensor = self.transform(image).unsqueeze(dim=0).to(self.device)
+def validate_target_class(target_class: Optional[int], num_classes: int) -> None:
+    if target_class is not None:
+        if target_class < 0 or target_class >= num_classes:
+            raise ValueError(f"target_class must be in range [0, {num_classes}), got {target_class}")

-        # Store original for visualization
-        rgb_img = np.array(image.resize(input_tensor.shape[-2:])).astype(np.float32) / 255.0

-        return (input_tensor, rgb_img)
+def predict_class(logits: torch.Tensor) -> int:
+    return int(torch.argmax(logits, dim=-1).item())
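The new base module turns Interpreter into a Protocol and exposes the shared helpers as free functions, so custom methods no longer subclass anything. A hedged sketch of a hypothetical gradient-saliency interpreter (not part of birder, plain input-gradient saliency rather than any method shipped in this release) that satisfies the protocol using only the helpers shown above:

from collections.abc import Callable
from pathlib import Path
from typing import Optional

import torch
from PIL import Image
from torch import nn

from birder.introspection.base import InterpretabilityResult
from birder.introspection.base import predict_class
from birder.introspection.base import preprocess_image
from birder.introspection.base import show_mask_on_image
from birder.introspection.base import validate_target_class


class GradientSaliency:  # hypothetical example, satisfies the Interpreter protocol
    def __init__(self, net: nn.Module, device: torch.device, transform: Callable[..., torch.Tensor]) -> None:
        self.net = net.eval()
        self.device = device
        self.transform = transform

    def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
        input_tensor.requires_grad_(True)

        logits = self.net(input_tensor)
        validate_target_class(target_class, logits.shape[-1])
        idx = target_class if target_class is not None else logits[0].argmax()
        logits[0, idx].backward()

        # Plain gradient saliency, normalized to [0, 1]
        mask = input_tensor.grad[0].abs().amax(dim=0)
        mask = (mask / (mask.max() + 1e-8)).cpu().numpy()

        visualization = show_mask_on_image(rgb_img, mask, image_weight=0.5)
        return InterpretabilityResult(
            original_image=rgb_img,
            visualization=visualization,
            raw_output=mask,
            logits=logits.detach(),
            predicted_class=predict_class(logits),
        )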