birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -56,7 +56,7 @@ def pixel_eps_to_normalized(
 
 
  def clamp_normalized(inputs: torch.Tensor, rgb_stats: RGBType) -> torch.Tensor:
- (min_val, max_val) = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
+ min_val, max_val = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
  return torch.clamp(inputs, min=min_val, max=max_val)
 
 
@@ -87,7 +87,7 @@ class SimBA:
  if self._is_successful(current_logits, label, target_label):
  return adv_inputs.detach(), num_queries
 
- (_, channels, height, width) = adv_inputs.shape
+ _, channels, height, width = adv_inputs.shape
  num_dims = channels * height * width
  step = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=adv_inputs.device, dtype=adv_inputs.dtype)
  step_vals = step.view(-1) # Per-channel steps
@@ -98,11 +98,11 @@ class SimBA:
 
  # Coordinate-wise search in random order
  for flat_idx in perm[:num_steps]:
- (c, rem) = divmod(int(flat_idx.item()), stride)
- (h, w) = divmod(rem, width)
+ c, rem = divmod(int(flat_idx.item()), stride)
+ h, w = divmod(rem, width)
  step_val = step_vals[c]
 
- (candidate_inputs, candidate_logits, candidate_objective) = self._best_candidate(
+ candidate_inputs, candidate_logits, candidate_objective = self._best_candidate(
  adv_inputs, c, h, w, step_val, label, target_label
  )
  num_queries += 2
birder/common/cli.py CHANGED
@@ -49,7 +49,7 @@ class FlexibleDictAction(argparse.Action):
  new_dict = {}
  for pair in pairs:
  # Split each pair into key and value
- (key, value) = pair.split("=", 1)
+ key, value = pair.split("=", 1)
  key = key.strip()
 
  # Try to safely evaluate the value (handles ints and strings mostly)
birder/common/fs_ops.py CHANGED
@@ -384,7 +384,7 @@ def load_checkpoint(
  )
 
  # Initialize network and restore checkpoint state
- net = registry.net_factory(network, input_channels, num_classes, config=config, size=size)
+ net = registry.net_factory(network, num_classes, input_channels, config=config, size=size)
 
  # When a checkpoint was trained with EMA:
  # The primary weights in the checkpoint file are the EMA weights
@@ -437,7 +437,7 @@ def load_mim_checkpoint(
  size = lib.get_size_from_signature(signature)
 
  # Initialize network and restore checkpoint state
- net_encoder = registry.net_factory(encoder, input_channels, num_classes, config=encoder_config, size=size)
+ net_encoder = registry.net_factory(encoder, num_classes, input_channels, config=encoder_config, size=size)
  net = registry.mim_net_factory(
  network, net_encoder, config=config, size=size, mask_ratio=mask_ratio, min_mask_size=min_mask_size
  )
@@ -488,7 +488,7 @@ def load_detection_checkpoint(
  size = lib.get_size_from_signature(signature)
 
  # Initialize network and restore checkpoint state
- net_backbone = registry.net_factory(backbone, input_channels, num_classes, config=backbone_config, size=size)
+ net_backbone = registry.net_factory(backbone, num_classes, input_channels, config=backbone_config, size=size)
  net = registry.detection_net_factory(network, num_classes, net_backbone, config=config, size=size)
 
  # When a checkpoint was trained with EMA:
@@ -584,7 +584,7 @@ def load_model(
  merged_config = None # type: ignore[assignment]
 
  model_state: dict[str, Any] = safetensors.torch.load_file(path, device=device.type)
- net = registry.net_factory(network, input_channels, num_classes, config=merged_config, size=size)
+ net = registry.net_factory(network, num_classes, input_channels, config=merged_config, size=size)
  if reparameterized is True:
  net.reparameterize_model()
 
@@ -611,7 +611,7 @@ def load_model(
  if len(merged_config) == 0:
  merged_config = None
 
- net = registry.net_factory(network, input_channels, num_classes, config=merged_config, size=size)
+ net = registry.net_factory(network, num_classes, input_channels, config=merged_config, size=size)
  if reparameterized is True:
  net.reparameterize_model()
 
@@ -733,7 +733,7 @@ def load_detection_model(
 
  model_state: dict[str, Any] = safetensors.torch.load_file(path, device=device.type)
  net_backbone = registry.net_factory(
- backbone, input_channels, num_classes, config=backbone_merged_config, size=size
+ backbone, num_classes, input_channels, config=backbone_merged_config, size=size
  )
  if backbone_reparameterized is True:
  net_backbone.reparameterize_model()
@@ -776,7 +776,7 @@ def load_detection_model(
  merged_config = None
 
  net_backbone = registry.net_factory(
- backbone, input_channels, num_classes, config=backbone_merged_config, size=size
+ backbone, num_classes, input_channels, config=backbone_merged_config, size=size
  )
  if backbone_reparameterized is True:
  net_backbone.reparameterize_model()
@@ -959,7 +959,7 @@ def load_model_with_cfg(
  encoder_name = cfg["encoder"]
 
  encoder_config = cfg.get("encoder_config", None)
- encoder = registry.net_factory(encoder_name, input_channels, num_classes=0, config=encoder_config, size=size)
+ encoder = registry.net_factory(encoder_name, 0, input_channels, config=encoder_config, size=size)
  net = registry.mim_net_factory(name, encoder, config=model_config, size=size)
 
  elif cfg["task"] == Task.OBJECT_DETECTION:
@@ -969,14 +969,14 @@ def load_model_with_cfg(
  backbone_name = cfg["backbone"]
 
  backbone_config = cfg.get("backbone_config", None)
- backbone = registry.net_factory(backbone_name, input_channels, num_classes, config=backbone_config, size=size)
+ backbone = registry.net_factory(backbone_name, num_classes, input_channels, config=backbone_config, size=size)
  if cfg.get("backbone_reparameterized", False) is True:
  backbone.reparameterize_model()
 
  net = registry.detection_net_factory(name, num_classes, backbone, config=model_config, size=size)
 
  elif cfg["task"] == Task.IMAGE_CLASSIFICATION:
- net = registry.net_factory(name, input_channels, num_classes, config=model_config, size=size)
+ net = registry.net_factory(name, num_classes, input_channels, config=model_config, size=size)
 
  else:
  raise ValueError(f"Configuration not supported: {cfg['task']}")
@@ -1019,7 +1019,7 @@ def download_model_by_weights(
  f"Requested format '{file_format}' not available for {weights}, available formats are: {available_formats}"
  )
 
- (model_file, url) = get_pretrained_model_url(weights, file_format)
+ model_file, url = get_pretrained_model_url(weights, file_format)
  if dst is None:
  dst = settings.MODELS_DIR.joinpath(model_file)
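Across these fs_ops call sites the positional order of registry.net_factory changes from (name, input_channels, num_classes, ...) to (name, num_classes, input_channels, ...). A minimal sketch of what an updated caller might look like; the model name, class count, channel count and size below are illustrative, and the registry import path is assumed from the module layout in the file list:

    from birder.model_registry import registry  # assumed import path

    num_classes = 10    # illustrative values
    input_channels = 3
    # 0.4.0 ordering was: registry.net_factory("some_net", input_channels, num_classes, size=(256, 256))
    net = registry.net_factory("some_net", num_classes, input_channels, size=(256, 256))  # 0.4.1 ordering

Only the two positional integers swap places; the keyword arguments (config=, size=) are passed the same way in both versions of the diff.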
birder/common/lib.py CHANGED
@@ -157,6 +157,6 @@ def get_pretrained_model_url(weights: str, file_format: str) -> tuple[str, str]:
 
  def format_duration(seconds: float) -> str:
  s = int(seconds)
- (mm, ss) = divmod(s, 60)
- (hh, mm) = divmod(mm, 60)
+ mm, ss = divmod(s, 60)
+ hh, mm = divmod(mm, 60)
  return f"{hh:d}:{mm:02d}:{ss:02d}"
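format_duration behaves the same after the tuple-unpacking cleanup; a quick worked example of the divmod chain above (import path taken from the file list):

    from birder.common.lib import format_duration

    # 3725.7 s -> int 3725; divmod(3725, 60) = (62, 5); divmod(62, 60) = (1, 2)
    assert format_duration(3725.7) == "1:02:05"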
birder/common/masking.py CHANGED
@@ -16,7 +16,7 @@ def _mask_token_omission(
  Parameters
  ----------
  x
- Tensor of shape (N, L, D), where N is the batch size, L is the sequence length, and D is the feature dimension.
+ Tensor of shape (N, L, D), where N is the batch size, L is the sequence length and D is the feature dimension.
  mask_ratio
  The ratio of the sequence length to be masked. This value should be between 0 and 1.
  kept_mask_ratio
@@ -48,7 +48,7 @@ def _mask_token_omission(
  # Masking: length -> length * mask_ratio
  # Perform per-sample random masking by per-sample shuffling.
  # Per-sample shuffling is done by argsort random noise.
- (N, L, D) = x.size() # batch, length, dim
+ N, L, D = x.size() # batch, length, dim
  len_keep = int(L * (1 - mask_ratio))
  len_masked = int(L * (mask_ratio - kept_mask_ratio))
 
@@ -82,7 +82,7 @@ def mask_tensor(
  if channels_last is False:
  x = x.permute(0, 2, 3, 1)
 
- (B, H, W, _) = x.size()
+ B, H, W, _ = x.size()
 
  shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
  shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
@@ -13,6 +13,7 @@ from birder.conf import settings
  from birder.data.datasets.coco import MosaicType
  from birder.data.transforms.classification import AugType
  from birder.data.transforms.classification import RGBMode
+ from birder.data.transforms.detection import MULTISCALE_STEP
  from birder.data.transforms.detection import AugType as DetAugType
 
  logger = logging.getLogger(__name__)
@@ -199,10 +200,16 @@ def add_detection_input_args(parser: argparse.ArgumentParser) -> None:
  action="store_true",
  help="enable random square resize once per batch (capped by max(--size))",
  )
+ group.add_argument(
+ "--multiscale-step",
+ type=int,
+ default=MULTISCALE_STEP,
+ help="step size for multiscale size lists and collator padding divisibility (size_divisible)",
+ )
  group.add_argument(
  "--multiscale-min-size",
  type=int,
- help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
+ help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of --multiscale-step)",
  )
 
 
@@ -515,7 +522,10 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
 
 
  def add_logging_and_debug_args(
- parser: argparse.ArgumentParser, default_log_interval: int = 50, fake_data: bool = True
+ parser: argparse.ArgumentParser,
+ default_log_interval: int = 50,
+ fake_data: bool = True,
+ classification: bool = False,
  ) -> None:
  group = parser.add_argument_group("Logging and debugging parameters")
  group.add_argument(
@@ -525,6 +535,11 @@ def add_logging_and_debug_args(
  metavar="NAME",
  help="experiment name for logging (creates dedicated directory for the run)",
  )
+ if classification is True:
+ group.add_argument(
+ "--top-k", type=int, metavar="K", help="additional top-k accuracy value to track (top-1 is always tracked)"
+ )
+
  group.add_argument(
  "--log-interval",
  type=int,
@@ -746,3 +761,10 @@ def common_args_validation(args: argparse.Namespace) -> None:
  # Precision_args, shared by all scripts
  if args.amp is True and args.model_dtype != "float32":
  raise ValidationError("--amp can only be used with --model-dtype float32")
+
+ if hasattr(args, "top_k") is True and args.top_k is not None:
+ if args.top_k == 1:
+ raise ValidationError("Top-1 accuracy is tracked by default, please remove 1 from --top-k argument")
+
+ if args.top_k <= 0:
+ raise ValidationError("--top-k value must be a positive integer")
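The new flags above are opt-in. An illustrative invocation, assuming the training scripts are run as modules; only --multiscale-min-size, --multiscale-step and --top-k come from this diff, everything else is a placeholder:

    # detection training: build the multiscale size list in steps of 64 starting at 512
    python -m birder.scripts.train_detection --multiscale-min-size 512 --multiscale-step 64 ...

    # classification training: also track top-5 accuracy (top-1 is always tracked; --top-k 1 is rejected by the validation above)
    python -m birder.scripts.train --top-k 5 ...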
@@ -11,6 +11,7 @@ from collections import deque
  from collections.abc import Callable
  from collections.abc import Generator
  from collections.abc import Iterator
+ from collections.abc import Sequence
  from datetime import datetime
  from pathlib import Path
  from typing import Any
@@ -361,7 +362,7 @@ def optimizer_parameter_groups(
  Return parameter groups for optimizers with per-parameter group weight decay.
 
  This function creates parameter groups with customizable weight decay, layer-wise
- learning rate scaling, and special handling for different parameter types. It supports
+ learning rate scaling and special handling for different parameter types. It supports
  advanced optimization techniques like layer decay and custom weight decay rules.
 
  Referenced from https://github.com/pytorch/vision/blob/main/references/classification/utils.py and from
@@ -450,7 +451,7 @@ def optimizer_parameter_groups(
  visited_modules = []
  while len(module_stack_with_prefix) > 0: # pylint: disable=too-many-nested-blocks
  skip_module = False
- (module, prefix) = module_stack_with_prefix.pop()
+ module, prefix = module_stack_with_prefix.pop()
  if id(module) in visited_modules:
  skip_module = True
 
@@ -884,6 +885,11 @@ class SmoothedValue:
  self.total: torch.Tensor | float = 0.0
  self.count: int = 0
 
+ def clear(self) -> None:
+ self.deque.clear()
+ self.total = 0.0
+ self.count = 0
+
  def update(self, value: torch.Tensor | float, n: int = 1) -> None:
  self.deque.append(value)
  self.count += n
@@ -927,14 +933,32 @@ class SmoothedValue:
  return to_tensor(v, torch.device("cpu")).item() # type: ignore[no-any-return]
 
 
- def accuracy(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
+ @torch.no_grad() # type: ignore[untyped-decorator]
+ def accuracy(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
  if y_pred.dim() > 1 and y_pred.size(1) > 1:
  y_pred = y_pred.argmax(dim=1)
 
  y_true = y_true.flatten()
  y_pred = y_pred.flatten()
 
- return (y_true == y_pred).float().mean().item() # type: ignore[no-any-return]
+ return (y_true == y_pred).sum() / y_true.numel()
+
+
+ @torch.no_grad() # type: ignore[untyped-decorator]
+ def topk_accuracy(y_true: torch.Tensor, y_pred: torch.Tensor, topk: Sequence[int]) -> list[torch.Tensor]:
+ maxk = min(max(topk), y_pred.size(1))
+ batch_size = y_true.size(0)
+
+ _, pred = y_pred.topk(maxk, dim=1, largest=True, sorted=True)
+ correct = pred.eq(y_true.unsqueeze(1))
+
+ res: list[torch.Tensor] = []
+ for k in topk:
+ k = min(k, maxk)
+ correct_k = correct[:, :k].any(dim=1).sum(dtype=torch.float32)
+ res.append((correct_k / batch_size))
+
+ return res
 
 
  ###############################################################################
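The new topk_accuracy helper returns one tensor per requested k. A small usage sketch, assuming both helpers are module-level and importable from the module listed above:

    import torch
    from birder.common.training_utils import topk_accuracy

    logits = torch.tensor([[0.1, 0.7, 0.2],
                           [0.6, 0.3, 0.1],
                           [0.2, 0.3, 0.5]])
    labels = torch.tensor([1, 2, 2])
    top1, top3 = topk_accuracy(labels, logits, topk=(1, 3))
    # top1 -> 2/3 (the second sample's argmax is class 0, not 2); top3 -> 1.0 (only 3 classes)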
@@ -70,13 +70,21 @@ class BatchRandomResizeCollator(DetectionCollator):
  size: tuple[int, int],
  size_divisible: int = 32,
  multiscale_min_size: Optional[int] = None,
+ multiscale_step: Optional[int] = None,
  ) -> None:
  super().__init__(input_offset, size_divisible=size_divisible)
  if size is None:
  raise ValueError("size must be provided for batch multiscale")
 
  max_side = max(size)
- sizes = [side for side in build_multiscale_sizes(multiscale_min_size) if side <= max_side]
+ if multiscale_step is None:
+ multiscale_step = size_divisible
+
+ sizes = []
+ for side in build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step):
+ if side <= max_side:
+ sizes.append(side)
+
  if len(sizes) == 0:
  sizes = [max_side]
 
@@ -17,17 +17,20 @@ DEFAULT_MULTISCALE_MAX_SIZE = 800
 
 
  def build_multiscale_sizes(
- min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE
+ min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE, multiscale_step: int = MULTISCALE_STEP
  ) -> tuple[int, ...]:
+ if multiscale_step <= 0:
+ raise ValueError("multiscale_step must be positive")
+
  if min_size is None:
  min_size = DEFAULT_MULTISCALE_MIN_SIZE
 
- start = int(math.ceil(min_size / MULTISCALE_STEP) * MULTISCALE_STEP)
- end = int(math.floor(max_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+ start = int(math.ceil(min_size / multiscale_step) * multiscale_step)
+ end = int(math.floor(max_size / multiscale_step) * multiscale_step)
  if end < start:
  return (start,)
 
- return tuple(range(start, end + 1, MULTISCALE_STEP))
+ return tuple(range(start, end + 1, multiscale_step))
 
 
  class ResizeWithRandomInterpolation(nn.Module):
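To make the new parameter concrete, a worked example of the rounding above (the min_size and max_size values are illustrative; the default value of MULTISCALE_STEP is not shown in this diff):

    from birder.data.transforms.detection import build_multiscale_sizes

    # 480 rounds up to the next multiple of 64 (512), 800 rounds down to 768
    assert build_multiscale_sizes(min_size=480, max_size=800, multiscale_step=64) == (512, 576, 640, 704, 768)
    # with a step of 100: 480 -> 500, 800 -> 800
    assert build_multiscale_sizes(min_size=480, max_size=800, multiscale_step=100) == (500, 600, 700, 800)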
@@ -59,6 +62,7 @@ def get_birder_augment(
  multiscale: bool,
  max_size: Optional[int],
  multiscale_min_size: Optional[int],
+ multiscale_step: int = MULTISCALE_STEP,
  post_mosaic: bool = False,
  ) -> Callable[..., torch.Tensor]:
  if dynamic_size is True:
@@ -98,7 +102,10 @@ def get_birder_augment(
  # Resize
  if multiscale is True:
  transformations.append(
- v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
+ v2.RandomShortestSize(
+ min_size=build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step),
+ max_size=max_size or 1333,
+ ),
  )
  else:
  transformations.append(
@@ -160,6 +167,7 @@ def training_preset(
  multiscale: bool = False,
  max_size: Optional[int] = None,
  multiscale_min_size: Optional[int] = None,
+ multiscale_step: int = MULTISCALE_STEP,
  post_mosaic: bool = False,
  ) -> Callable[..., torch.Tensor]:
  mean = rgv_values["mean"]
@@ -180,7 +188,15 @@ def training_preset(
  [
  v2.ToImage(),
  get_birder_augment(
- size, level, fill_value, dynamic_size, multiscale, max_size, multiscale_min_size, post_mosaic
+ size,
+ level,
+ fill_value,
+ dynamic_size,
+ multiscale,
+ max_size,
+ multiscale_min_size,
+ multiscale_step,
+ post_mosaic,
  ),
  v2.ToDtype(torch.float32, scale=True),
  v2.Normalize(mean=mean, std=std),
@@ -212,7 +228,10 @@ def training_preset(
  return v2.Compose( # type: ignore
  [
  v2.ToImage(),
- v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
+ v2.RandomShortestSize(
+ min_size=build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step),
+ max_size=max_size or 1333,
+ ),
  v2.RandomHorizontalFlip(0.5),
  v2.SanitizeBoundingBoxes(),
  v2.ToDtype(torch.float32, scale=True),
@@ -284,7 +303,7 @@ def training_preset(
  )
 
  if aug_type == "detr":
- multiscale_sizes = build_multiscale_sizes(multiscale_min_size)
+ multiscale_sizes = build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step)
  return v2.Compose( # type: ignore
  [
  v2.ToImage(),
@@ -19,7 +19,7 @@ def mosaic_random_center(
  Create a mosaic augmentation by combining 4 images into a single image.
 
  This augmentation places 4 images on a canvas, meeting at a randomly selected
- center point. Each image is scaled to fit, cropped as needed, and their bounding
+ center point. Each image is scaled to fit, cropped as needed and their bounding
  boxes are transformed accordingly.
 
  Parameters
@@ -63,7 +63,7 @@ class TestDataset(ImageFolder):
  super().__init__(self._root.joinpath(split), transform, target_transform, loader, is_valid_file)
 
  def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any]:
- (path, target) = self.samples[index]
+ path, target = self.samples[index]
  sample = self.loader(path)
  if self.transform is not None:
  sample = self.transform(sample)
@@ -122,7 +122,7 @@ class Flowers102(ImageFolder):
  super().__init__(self._root.joinpath(split), transform, target_transform, loader, is_valid_file)
 
  def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any]:
- (path, target) = self.samples[index]
+ path, target = self.samples[index]
  sample = self.loader(path)
  if self.transform is not None:
  sample = self.transform(sample)
@@ -182,7 +182,7 @@ class CUB_200_2011(ImageFolder):
  super().__init__(self._root.joinpath(split), transform, target_transform, loader, is_valid_file)
 
  def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any]:
- (path, target) = self.samples[index]
+ path, target = self.samples[index]
  sample = self.loader(path)
  if self.transform is not None:
  sample = self.transform(sample)
@@ -75,7 +75,7 @@ def infer_batch(
  embedding = embedding_tensor.cpu().float().numpy()
 
  elif tta is True:
- (_, _, H, W) = inputs.size()
+ _, _, H, W = inputs.size()
  crop_h = int(H * 0.8)
  crop_w = int(W * 0.8)
  tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
@@ -137,7 +137,7 @@ def infer_dataloader_iter(
  inputs = inputs.to(device, dtype=model_dtype)
 
  with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
- (out, embedding) = infer_batch(
+ out, embedding = infer_batch(
  net, inputs, return_embedding=return_embedding, tta=tta, return_logits=return_logits, **kwargs
  )
 
@@ -394,7 +394,7 @@ def evaluate(
  num_samples: Optional[int] = None,
  sparse: bool = False,
  ) -> Results | SparseResults:
- (sample_paths, outs, labels, _) = infer_dataloader(
+ sample_paths, outs, labels, _ = infer_dataloader(
  device, net, dataloader, tta=tta, model_dtype=model_dtype, amp=amp, amp_dtype=amp_dtype, num_samples=num_samples
  )
  if sparse is True:
@@ -253,7 +253,7 @@ class InferenceDataParallel(nn.Module):
 
  This allows custom methods (e.g., model.embedding()) to be called
  on the InferenceDataParallel instance, which then scatters inputs,
- calls the method on each replica, and gathers the results.
+ calls the method on each replica and gathers the results.
 
  Parameters
  ----------
@@ -20,7 +20,7 @@ def _normalize_image_sizes(inputs: torch.Tensor, image_sizes: Optional[list[list
  if image_sizes is not None:
  return image_sizes
 
- (_, _, height, width) = inputs.shape
+ _, _, height, width = inputs.shape
  return [[height, width] for _ in range(inputs.size(0))]
 
 
@@ -149,20 +149,20 @@ def infer_batch(
  **kwargs: Any,
  ) -> list[dict[str, torch.Tensor]]:
  if tta is False:
- (detections, _) = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
+ detections, _ = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
  return detections # type: ignore[no-any-return]
 
  normalized_sizes = _normalize_image_sizes(inputs, image_sizes)
  detections_list: list[list[dict[str, torch.Tensor]]] = []
 
  for scale in (0.8, 1.0, 1.2):
- (scaled_inputs, scaled_masks, scaled_sizes) = _resize_batch(inputs, normalized_sizes, scale, size_divisible=32)
- (detections, _) = net(scaled_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+ scaled_inputs, scaled_masks, scaled_sizes = _resize_batch(inputs, normalized_sizes, scale, size_divisible=32)
+ detections, _ = net(scaled_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
  detections = _rescale_detections(detections, scaled_sizes, normalized_sizes)
  detections_list.append(detections)
 
  flipped_inputs = _hflip_inputs(scaled_inputs, scaled_sizes)
- (flipped_detections, _) = net(flipped_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+ flipped_detections, _ = net(flipped_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
  flipped_detections = _invert_detections(flipped_detections, scaled_sizes)
  flipped_detections = _rescale_detections(flipped_detections, scaled_sizes, normalized_sizes)
  detections_list.append(flipped_detections)
birder/inference/wbf.py CHANGED
@@ -182,7 +182,7 @@ def fuse_detections_wbf_single(
  scores_list = [detection["scores"] for detection in detections]
  labels_list = [detection["labels"] for detection in detections]
 
- (boxes, scores, labels) = weighted_boxes_fusion(
+ boxes, scores, labels = weighted_boxes_fusion(
  boxes_list,
  scores_list,
  labels_list,
@@ -70,7 +70,7 @@ def compute_rollout(
  num_to_discard = int(num_allowed * discard_ratio)
  if num_to_discard > 0:
  # Drop the smallest allowed values
- (_, low_idx) = torch.topk(allowed_values, num_to_discard, largest=False)
+ _, low_idx = torch.topk(allowed_values, num_to_discard, largest=False)
  allowed_values[low_idx] = 0
  attn[allow] = allowed_values
  attention_heads_fused[0] = attn
@@ -97,7 +97,7 @@ def compute_rollout(
 
  # Normalize and reshape to 2D map using actual patch grid dimensions
  mask = mask / (mask.max() + 1e-8)
- (grid_h, grid_w) = patch_grid_shape
+ grid_h, grid_w = patch_grid_shape
  mask = mask.reshape(grid_h, grid_w)
 
  return mask
@@ -156,11 +156,11 @@ class AttentionRollout:
  self.attention_gatherer = AttentionGatherer(net, attention_layer_name)
 
  def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
- (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+ input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
 
- (attentions, logits) = self.attention_gatherer(input_tensor)
+ attentions, logits = self.attention_gatherer(input_tensor)
 
- (_, _, H, W) = input_tensor.shape
+ _, _, H, W = input_tensor.shape
  patch_grid_shape = (H // self.net.stem_stride, W // self.net.stem_stride)
 
  attention_map = compute_rollout(
@@ -17,7 +17,7 @@ class FeaturePCA:
  Visualizes feature maps using Principal Component Analysis
 
  This method extracts feature maps from a specified stage of a DetectorBackbone model,
- applies PCA to reduce the channel dimension to 3 components, and visualizes them as an RGB image where:
+ applies PCA to reduce the channel dimension to 3 components and visualizes them as an RGB image where:
  - R channel = 1st principal component (most important)
  - G channel = 2nd principal component
  - B channel = 3rd principal component
@@ -40,7 +40,7 @@ class FeaturePCA:
  self.stage = stage
 
  def __call__(self, image: str | Path | Image.Image) -> InterpretabilityResult:
- (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+ input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
 
  with torch.inference_mode():
  features_dict = self.net.detection_features(input_tensor)
@@ -54,11 +54,11 @@ class FeaturePCA:
 
  # Handle channels_last format (B, H, W, C) vs channels_first (B, C, H, W)
  if self.channels_last is True:
- (B, H, W, C) = features_np.shape
+ B, H, W, C = features_np.shape
  # Already in (B, H, W, C), just reshape to (B*H*W, C)
  features_reshaped = features_np.reshape(-1, C)
  else:
- (B, C, H, W) = features_np.shape
+ B, C, H, W = features_np.shape
  # Reshape to (spatial_points, channels) for PCA
  features_reshaped = features_np.reshape(B, C, -1)
  features_reshaped = features_reshaped.transpose(0, 2, 1) # (B, H*W, C)
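The channels_first branch above simply flattens the spatial grid so that each location becomes one PCA sample; a standalone numpy sketch with an illustrative feature-map shape:

    import numpy as np

    features_np = np.random.rand(1, 256, 20, 20)                 # (B, C, H, W), illustrative shape
    flat = features_np.reshape(1, 256, -1).transpose(0, 2, 1)    # (B, H*W, C) = (1, 400, 256)
    samples = flat.reshape(-1, 256)                              # 400 spatial points, 256 features each, ready for PCA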
@@ -98,7 +98,7 @@ class GradCAM:
  self.activation_capture = ActivationCapture(net, target_layer, reshape_transform)
 
  def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
- (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+ input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
  input_tensor.requires_grad_(True)
 
  # Forward pass
@@ -38,7 +38,7 @@ class GuidedBackpropReLU(Function):
 
  @staticmethod
  def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
- (input_img, _output) = ctx.saved_tensors
+ input_img, _output = ctx.saved_tensors
 
  positive_mask_1 = (input_img > 0).type_as(grad_output)
  positive_mask_2 = (grad_output > 0).type_as(grad_output)
@@ -190,7 +190,7 @@ class GuidedBackprop:
  self.transform = transform
 
  def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
- (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+ input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
 
  # Get prediction
  with torch.inference_mode():