birder 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/fs_ops.py +2 -2
  4. birder/common/masking.py +13 -4
  5. birder/common/training_cli.py +6 -1
  6. birder/common/training_utils.py +4 -2
  7. birder/inference/classification.py +1 -1
  8. birder/introspection/__init__.py +2 -0
  9. birder/introspection/base.py +0 -7
  10. birder/introspection/feature_pca.py +101 -0
  11. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  12. birder/model_registry/model_registry.py +3 -2
  13. birder/net/base.py +3 -3
  14. birder/net/biformer.py +2 -2
  15. birder/net/cas_vit.py +6 -6
  16. birder/net/coat.py +8 -8
  17. birder/net/conv2former.py +2 -2
  18. birder/net/convnext_v1.py +22 -2
  19. birder/net/convnext_v2.py +2 -2
  20. birder/net/crossformer.py +2 -2
  21. birder/net/cspnet.py +2 -2
  22. birder/net/cswin_transformer.py +2 -2
  23. birder/net/darknet.py +2 -2
  24. birder/net/davit.py +2 -2
  25. birder/net/deit.py +3 -3
  26. birder/net/deit3.py +3 -3
  27. birder/net/densenet.py +2 -2
  28. birder/net/detection/deformable_detr.py +2 -2
  29. birder/net/detection/detr.py +2 -2
  30. birder/net/detection/efficientdet.py +2 -2
  31. birder/net/detection/faster_rcnn.py +2 -2
  32. birder/net/detection/fcos.py +2 -2
  33. birder/net/detection/retinanet.py +2 -2
  34. birder/net/detection/rt_detr_v1.py +4 -4
  35. birder/net/detection/ssd.py +2 -2
  36. birder/net/detection/ssdlite.py +2 -2
  37. birder/net/detection/yolo_v2.py +2 -2
  38. birder/net/detection/yolo_v3.py +2 -2
  39. birder/net/detection/yolo_v4.py +2 -2
  40. birder/net/edgenext.py +2 -2
  41. birder/net/edgevit.py +1 -1
  42. birder/net/efficientformer_v1.py +4 -4
  43. birder/net/efficientformer_v2.py +6 -6
  44. birder/net/efficientnet_lite.py +2 -2
  45. birder/net/efficientnet_v1.py +2 -2
  46. birder/net/efficientnet_v2.py +2 -2
  47. birder/net/efficientvim.py +3 -3
  48. birder/net/efficientvit_mit.py +2 -2
  49. birder/net/efficientvit_msft.py +2 -2
  50. birder/net/fasternet.py +2 -2
  51. birder/net/fastvit.py +2 -3
  52. birder/net/flexivit.py +11 -6
  53. birder/net/focalnet.py +2 -3
  54. birder/net/gc_vit.py +17 -2
  55. birder/net/ghostnet_v1.py +2 -2
  56. birder/net/ghostnet_v2.py +2 -2
  57. birder/net/groupmixformer.py +2 -2
  58. birder/net/hgnet_v1.py +2 -2
  59. birder/net/hgnet_v2.py +2 -2
  60. birder/net/hiera.py +2 -2
  61. birder/net/hieradet.py +2 -2
  62. birder/net/hornet.py +2 -2
  63. birder/net/iformer.py +2 -2
  64. birder/net/inception_next.py +2 -2
  65. birder/net/inception_resnet_v1.py +2 -2
  66. birder/net/inception_resnet_v2.py +2 -2
  67. birder/net/inception_v3.py +2 -2
  68. birder/net/inception_v4.py +2 -2
  69. birder/net/levit.py +4 -4
  70. birder/net/lit_v1.py +2 -2
  71. birder/net/lit_v1_tiny.py +2 -2
  72. birder/net/lit_v2.py +2 -2
  73. birder/net/maxvit.py +2 -2
  74. birder/net/metaformer.py +2 -2
  75. birder/net/mnasnet.py +2 -2
  76. birder/net/mobilenet_v1.py +2 -2
  77. birder/net/mobilenet_v2.py +2 -2
  78. birder/net/mobilenet_v3_large.py +2 -2
  79. birder/net/mobilenet_v4.py +2 -2
  80. birder/net/mobilenet_v4_hybrid.py +2 -2
  81. birder/net/mobileone.py +2 -2
  82. birder/net/mobilevit_v2.py +2 -2
  83. birder/net/moganet.py +2 -2
  84. birder/net/mvit_v2.py +2 -2
  85. birder/net/nextvit.py +2 -2
  86. birder/net/nfnet.py +2 -2
  87. birder/net/pit.py +6 -6
  88. birder/net/pvt_v1.py +2 -2
  89. birder/net/pvt_v2.py +2 -2
  90. birder/net/rdnet.py +2 -2
  91. birder/net/regionvit.py +6 -6
  92. birder/net/regnet.py +2 -2
  93. birder/net/regnet_z.py +2 -2
  94. birder/net/repghost.py +2 -2
  95. birder/net/repvgg.py +2 -2
  96. birder/net/repvit.py +6 -6
  97. birder/net/resnest.py +2 -2
  98. birder/net/resnet_v1.py +2 -2
  99. birder/net/resnet_v2.py +2 -2
  100. birder/net/resnext.py +2 -2
  101. birder/net/rope_deit3.py +3 -3
  102. birder/net/rope_flexivit.py +13 -6
  103. birder/net/rope_vit.py +69 -10
  104. birder/net/shufflenet_v1.py +2 -2
  105. birder/net/shufflenet_v2.py +2 -2
  106. birder/net/smt.py +1 -2
  107. birder/net/squeezenext.py +2 -2
  108. birder/net/ssl/byol.py +3 -2
  109. birder/net/ssl/capi.py +156 -11
  110. birder/net/ssl/data2vec.py +3 -1
  111. birder/net/ssl/data2vec2.py +3 -1
  112. birder/net/ssl/dino_v1.py +1 -1
  113. birder/net/ssl/dino_v2.py +140 -18
  114. birder/net/ssl/franca.py +145 -13
  115. birder/net/ssl/ibot.py +1 -2
  116. birder/net/ssl/mmcr.py +3 -1
  117. birder/net/starnet.py +2 -2
  118. birder/net/swiftformer.py +6 -6
  119. birder/net/swin_transformer_v1.py +2 -2
  120. birder/net/swin_transformer_v2.py +2 -2
  121. birder/net/tiny_vit.py +2 -2
  122. birder/net/transnext.py +1 -1
  123. birder/net/uniformer.py +1 -1
  124. birder/net/van.py +1 -1
  125. birder/net/vgg.py +1 -1
  126. birder/net/vgg_reduced.py +1 -1
  127. birder/net/vit.py +172 -8
  128. birder/net/vit_parallel.py +5 -5
  129. birder/net/vit_sam.py +3 -3
  130. birder/net/vovnet_v1.py +2 -2
  131. birder/net/vovnet_v2.py +2 -2
  132. birder/net/wide_resnet.py +2 -2
  133. birder/net/xception.py +2 -2
  134. birder/net/xcit.py +2 -2
  135. birder/results/detection.py +104 -0
  136. birder/results/gui.py +10 -8
  137. birder/scripts/benchmark.py +1 -1
  138. birder/scripts/train.py +13 -18
  139. birder/scripts/train_barlow_twins.py +10 -14
  140. birder/scripts/train_byol.py +11 -15
  141. birder/scripts/train_capi.py +38 -17
  142. birder/scripts/train_data2vec.py +11 -15
  143. birder/scripts/train_data2vec2.py +13 -17
  144. birder/scripts/train_detection.py +11 -14
  145. birder/scripts/train_dino_v1.py +20 -22
  146. birder/scripts/train_dino_v2.py +126 -63
  147. birder/scripts/train_dino_v2_dist.py +127 -64
  148. birder/scripts/train_franca.py +49 -34
  149. birder/scripts/train_i_jepa.py +11 -14
  150. birder/scripts/train_ibot.py +16 -18
  151. birder/scripts/train_kd.py +14 -20
  152. birder/scripts/train_mim.py +10 -13
  153. birder/scripts/train_mmcr.py +11 -15
  154. birder/scripts/train_rotnet.py +12 -16
  155. birder/scripts/train_simclr.py +10 -14
  156. birder/scripts/train_vicreg.py +10 -14
  157. birder/tools/avg_model.py +24 -8
  158. birder/tools/det_results.py +91 -0
  159. birder/tools/introspection.py +35 -9
  160. birder/tools/results.py +11 -7
  161. birder/tools/show_iterator.py +1 -1
  162. birder/version.py +1 -1
  163. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
  164. birder-0.3.2.dist-info/RECORD +299 -0
  165. birder-0.3.0.dist-info/RECORD +0 -298
  166. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
  167. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
  168. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
  169. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/adversarial/deepfool.py CHANGED
@@ -2,6 +2,8 @@
 DeepFool
 
 Paper "DeepFool: a simple and accurate method to fool deep neural networks", https://arxiv.org/abs/1511.04599
+
+Generated by gpt-5.2-codex xhigh.
 """
 
 from typing import Optional
birder/adversarial/simba.py CHANGED
@@ -2,6 +2,8 @@
 SimBA (Simple Black-box Attack)
 
 Paper "Simple Black-box Adversarial Attacks", https://arxiv.org/abs/1905.07121
+
+Generated by gpt-5.2-codex xhigh.
 """
 
 from typing import Optional
birder/common/fs_ops.py CHANGED
@@ -627,7 +627,7 @@ def load_model(
     net.to(dtype)
     if inference is True:
         for param in net.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
     if pt2 is False:  # NOTE: Remove when GraphModule add support for 'eval'
         net.eval()
@@ -799,7 +799,7 @@ def load_detection_model(
     net.to(dtype)
     if inference is True:
         for param in net.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
     net.eval()
 
birder/common/masking.py CHANGED
@@ -84,8 +84,8 @@ def mask_tensor(
 
     (B, H, W, _) = x.size()
 
-    shaped_mask = mask.reshape(-1, H // patch_factor, W // patch_factor)
-    shaped_mask = shaped_mask.repeat_interleave(patch_factor, axis=1).repeat_interleave(patch_factor, axis=2)
+    shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
+    shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
     shaped_mask = shaped_mask.unsqueeze(3).type_as(x)
 
     if mask_token is not None:
@@ -228,14 +228,23 @@ class Masking:
 
 
 class UniformMasking(Masking):
-    def __init__(self, input_size: tuple[int, int], mask_ratio: float, device: Optional[torch.device] = None) -> None:
+    def __init__(
+        self,
+        input_size: tuple[int, int],
+        mask_ratio: float,
+        min_mask_size: int = 1,
+        device: Optional[torch.device] = None,
+    ) -> None:
         self.h = input_size[0]
         self.w = input_size[1]
         self.mask_ratio = mask_ratio
+        self.min_mask_size = min_mask_size
         self.device = device
 
     def __call__(self, batch_size: int) -> torch.Tensor:
-        return uniform_mask(batch_size, self.h, self.w, self.mask_ratio, device=self.device)[0]
+        return uniform_mask(
+            batch_size, self.h, self.w, self.mask_ratio, min_mask_size=self.min_mask_size, device=self.device
+        )[0]
 
 
 class BlockMasking(Masking):
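The mask_tensor() fix above keeps the batch dimension explicit and passes dim= (the keyword torch.Tensor.repeat_interleave takes) when expanding the per-patch mask to pixel resolution. A self-contained sketch of that expansion, outside the birder API:

import torch

(B, H, W, patch_factor) = (2, 8, 8, 4)  # feature map is (B, H, W, C); the mask has one entry per patch
mask = torch.randint(0, 2, (B, (H // patch_factor) * (W // patch_factor)))

shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)  # keep the batch dimension explicit
shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
print(shaped_mask.shape)  # torch.Size([2, 8, 8])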
birder/common/training_cli.py CHANGED
@@ -39,6 +39,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
     group = parser.add_argument_group("Optimization parameters")
     group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
     group.add_argument("--opt", type=str, choices=list(get_args(OptimizerType)), default="sgd", help="optimizer to use")
+    group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
     group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
     group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
     group.add_argument("--opt-eps", type=float, help="optimizer epsilon (None to use the optimizer default)")
@@ -249,6 +250,7 @@ def add_data_aug_args(
     default_level: int = 4,
     default_min_scale: Optional[float] = None,
     default_re_prob: Optional[float] = None,
+    smoothing_alpha: bool = False,
     mixup_cutmix: bool = False,
 ) -> None:
     group = parser.add_argument_group("Data augmentation parameters")
@@ -285,6 +287,8 @@
     group.add_argument(
         "--simple-crop", default=False, action="store_true", help="use simple random crop (SRC) instead of RRC"
     )
+    if smoothing_alpha is True:
+        group.add_argument("--smoothing-alpha", type=float, default=0.0, help="label smoothing alpha")
     if mixup_cutmix is True:
         group.add_argument("--mixup-alpha", type=float, help="mixup alpha")
         group.add_argument("--cutmix", default=False, action="store_true", help="enable cutmix")
@@ -565,9 +569,9 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
     group.add_argument("--wds", default=False, action="store_true", help="use webdataset for training")
     group.add_argument("--wds-info", type=str, metavar="FILE", help="wds info file path")
     group.add_argument("--wds-cache-dir", type=str, metavar="DIR", help="webdataset cache directory")
-    group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
     if unsupervised is False:
         group.add_argument("--wds-class-file", type=str, metavar="FILE", help="class list file")
+        group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
         group.add_argument("--wds-val-size", type=int, metavar="N", help="size of the wds validation set")
         group.add_argument(
             "--wds-training-split", type=str, default="training", metavar="NAME", help="wds dataset train split"
@@ -576,6 +580,7 @@
             "--wds-val-split", type=str, default="validation", metavar="NAME", help="wds dataset validation split"
         )
     else:
+        group.add_argument("--wds-size", type=int, metavar="N", help="size of the wds")
        group.add_argument(
             "--wds-split", type=str, default="training", metavar="NAME", help="wds dataset split to load"
         )
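The new --smoothing-alpha flag carries a label smoothing factor; how the training scripts consume it is not part of this diff, but in plain PyTorch such a value maps onto the label_smoothing argument of cross-entropy. A minimal sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
smoothing_alpha = 0.1  # the value a user would pass via --smoothing-alpha

# Label smoothing in stock PyTorch; the exact wiring inside the birder scripts is not shown here
loss = F.cross_entropy(logits, targets, label_smoothing=smoothing_alpha)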
birder/common/training_utils.py CHANGED
@@ -593,12 +593,14 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
         kwargs["betas"] = args.opt_betas
     if getattr(args, "opt_alpha", None) is not None:
         kwargs["alpha"] = args.opt_alpha
+    if getattr(args, "opt_fused", False) is True:
+        kwargs["fused"] = True
 
     # For optimizer compilation
     # lr = torch.tensor(l_rate) - Causes weird LR scheduling bugs
     lr = l_rate
-    if getattr(args, "compile_opt", False) is not False:
-        if opt not in ("lamb", "lambw", "lars"):
+    if getattr(args, "compile_opt", False) is True:
+        if opt not in ("sgd", "lamb", "lambw", "lars"):
             logger.debug("Setting optimizer capturable to True")
             kwargs["capturable"] = True
 
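The new opt_fused path simply adds fused=True to the optimizer kwargs. In stock PyTorch that selects the fused implementation of optimizers such as AdamW; a minimal sketch, not birder's get_optimizer, noting that the fused variants target CUDA tensors:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
params = [torch.nn.Parameter(torch.randn(4, 4, device=device))]

kwargs = {}
if device.type == "cuda":  # stands in for the --opt-fused / args.opt_fused check
    kwargs["fused"] = True  # fused torch.optim implementations expect CUDA parameters

optimizer = torch.optim.AdamW(params, lr=1e-3, **kwargs)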
birder/inference/classification.py CHANGED
@@ -85,7 +85,7 @@ def infer_batch(
             logits = net(t(tta_input), **kwargs)
             outs.append(logits if return_logits is True else F.softmax(logits, dim=1))
 
-        out = torch.stack(outs).mean(axis=0)
+        out = torch.stack(outs).mean(dim=0)
 
     else:
         logits = net(inputs, **kwargs)
birder/introspection/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from birder.introspection.attention_rollout import AttentionRollout
 from birder.introspection.base import InterpretabilityResult
+from birder.introspection.feature_pca import FeaturePCA
 from birder.introspection.gradcam import GradCAM
 from birder.introspection.guided_backprop import GuidedBackprop
 from birder.introspection.transformer_attribution import TransformerAttribution
@@ -7,6 +8,7 @@ from birder.introspection.transformer_attribution import TransformerAttribution
 __all__ = [
     "InterpretabilityResult",
     "AttentionRollout",
+    "FeaturePCA",
     "GradCAM",
     "GuidedBackprop",
     "TransformerAttribution",
birder/introspection/base.py CHANGED
@@ -2,7 +2,6 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
-from typing import Protocol
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -27,12 +26,6 @@
         plt.show()
 
 
-class Interpreter(Protocol):
-    def __call__(
-        self, image: str | Path | Image.Image, target_class: Optional[int] = None
-    ) -> InterpretabilityResult: ...
-
-
 def load_image(image: str | Path | Image.Image) -> Image.Image:
     if isinstance(image, (str, Path)):
         return Image.open(image)
birder/introspection/feature_pca.py ADDED
@@ -0,0 +1,101 @@
+from collections.abc import Callable
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from sklearn.decomposition import PCA
+
+from birder.introspection.base import InterpretabilityResult
+from birder.introspection.base import preprocess_image
+from birder.net.base import DetectorBackbone
+
+
+class FeaturePCA:
+    """
+    Visualizes feature maps using Principal Component Analysis
+
+    This method extracts feature maps from a specified stage of a DetectorBackbone model,
+    applies PCA to reduce the channel dimension to 3 components, and visualizes them as an RGB image where:
+    - R channel = 1st principal component (most important)
+    - G channel = 2nd principal component
+    - B channel = 3rd principal component
+    """
+
+    def __init__(
+        self,
+        net: DetectorBackbone,
+        device: torch.device,
+        transform: Callable[..., torch.Tensor],
+        normalize: bool = False,
+        channels_last: bool = False,
+        stage: Optional[str] = None,
+    ) -> None:
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
+        self.normalize = normalize
+        self.channels_last = channels_last
+        self.stage = stage
+
+    def __call__(self, image: str | Path | Image.Image) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+
+        with torch.inference_mode():
+            features_dict = self.net.detection_features(input_tensor)
+
+        if self.stage is not None:
+            features = features_dict[self.stage]
+        else:
+            features = list(features_dict.values())[-1]  # Use the last stage by default
+
+        features_np = features.cpu().numpy()
+
+        # Handle channels_last format (B, H, W, C) vs channels_first (B, C, H, W)
+        if self.channels_last is True:
+            (B, H, W, C) = features_np.shape
+            # Already in (B, H, W, C), just reshape to (B*H*W, C)
+            features_reshaped = features_np.reshape(-1, C)
+        else:
+            (B, C, H, W) = features_np.shape
+            # Reshape to (spatial_points, channels) for PCA
+            features_reshaped = features_np.reshape(B, C, -1)
+            features_reshaped = features_reshaped.transpose(0, 2, 1)  # (B, H*W, C)
+            features_reshaped = features_reshaped.reshape(-1, C)  # (B*H*W, C)
+
+        x = features_reshaped
+        if self.normalize is True:
+            x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-6)
+
+        pca = PCA(n_components=3)
+        pca_features = pca.fit_transform(x)
+        pca_features = pca_features.reshape(B, H, W, 3)
+
+        # Extract all 3 components (B=1)
+        pca_rgb = pca_features[0]  # (H, W, 3)
+
+        # Normalize each channel independently to [0, 1]
+        for i in range(3):
+            channel = pca_rgb[:, :, i]
+            channel = channel - channel.min()
+            channel = channel / (channel.max() + 1e-7)
+            pca_rgb[:, :, i] = channel
+
+        target_size = (input_tensor.size(-1), input_tensor.size(-2))  # PIL expects (width, height)
+        pca_rgb_resized = (
+            np.array(
+                Image.fromarray((pca_rgb * 255).astype(np.uint8)).resize(target_size, Image.Resampling.BILINEAR)
+            ).astype(np.float32)
+            / 255.0
+        )
+
+        visualization = (pca_rgb_resized * 255).astype(np.uint8)
+
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=pca_rgb.astype(np.float32),
+            logits=None,
+            predicted_class=None,
+        )
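The core of the visualization is PCA over the channel dimension followed by per-component normalization to [0, 1]. A self-contained sketch with a random feature map standing in for a real DetectorBackbone (loading an actual backbone and transform is outside this diff):

import numpy as np
from sklearn.decomposition import PCA

(C, H, W) = (256, 14, 14)
features = np.random.rand(1, C, H, W).astype(np.float32)  # stand-in for detection_features() output, B=1

x = features.reshape(1, C, -1).transpose(0, 2, 1).reshape(-1, C)  # (B*H*W, C)
pca_rgb = PCA(n_components=3).fit_transform(x).reshape(H, W, 3)

for i in range(3):  # normalize each principal component independently for display
    channel = pca_rgb[:, :, i]
    pca_rgb[:, :, i] = (channel - channel.min()) / (channel.max() - channel.min() + 1e-7)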
birder/kernels/soft_nms/soft_nms.cpp CHANGED
@@ -4,6 +4,9 @@
  * Taken from:
  * https://github.com/MrParosk/soft_nms
  * Licensed under the MIT License
+ *
+ * Modified by:
+ * Ofer Hasson — 2026-01-10
  **************************************************************************************************
  */
 
@@ -40,8 +43,8 @@ torch::Tensor calculate_iou(const torch::Tensor& boxes, const torch::Tensor& are
     auto xx2 = torch::minimum(boxes.index({idx, 2}), boxes.index({Slice(idx + 1, None), 2}));
     auto yy2 = torch::minimum(boxes.index({idx, 3}), boxes.index({Slice(idx + 1, None), 3}));
 
-    auto w = torch::maximum(torch::zeros_like(xx1), xx2 - xx1);
-    auto h = torch::maximum(torch::zeros_like(yy1), yy2 - yy1);
+    auto w = (xx2 - xx1).clamp_min(0);
+    auto h = (yy2 - yy1).clamp_min(0);
 
     auto intersection = w * h;
     auto union_ = areas.index({idx}) + areas.index({Slice(idx + 1, None)}) - intersection;
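The clamp_min(0) rewrite is behavior-preserving: for disjoint boxes the raw width and height are negative and still collapse to zero, so the intersection stays zero. The same arithmetic in Python-side torch for reference (a sketch, not the extension's code):

import torch

box_a = torch.tensor([0.0, 0.0, 2.0, 2.0])  # (x1, y1, x2, y2)
box_b = torch.tensor([3.0, 3.0, 5.0, 5.0])  # disjoint from box_a

w = (torch.minimum(box_a[2], box_b[2]) - torch.maximum(box_a[0], box_b[0])).clamp_min(0)  # raw value is -1
h = (torch.minimum(box_a[3], box_b[3]) - torch.maximum(box_a[1], box_b[1])).clamp_min(0)
intersection = w * h  # 0.0, so the IoU of disjoint boxes stays 0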
birder/model_registry/model_registry.py CHANGED
@@ -87,14 +87,15 @@ class ModelRegistry:
         no further registration is needed.
         """
 
+        alias_key = alias.lower()
         if net_type.auto_register is False:
             # Register the model manually, as the base class doesn't take care of that for us
-            registry.register_model(alias, type(alias, (net_type,), {"config": config}))
+            self.register_model(alias_key, type(alias, (net_type,), {"config": config}))
 
         if alias in self.aliases:
             warnings.warn(f"Alias {alias} is already registered", UserWarning)
 
-        self.aliases[alias] = type(alias, (net_type,), {"config": config})
+        self.aliases[alias_key] = type(alias, (net_type,), {"config": config})
 
     def register_weights(self, name: str, weights_info: manifest.ModelMetadataType) -> None:
         if name in self._pretrained_nets:
birder/net/base.py CHANGED
@@ -173,14 +173,14 @@ class BaseNet(nn.Module):
 
     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         if freeze_classifier is False:
             for param in self.classifier.parameters():
-                param.requires_grad = True
+                param.requires_grad_(True)
         if unfreeze_features is True and hasattr(self, "features") is True:
             for param in self.features.parameters():
-                param.requires_grad = True
+                param.requires_grad_(True)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         """
birder/net/biformer.py CHANGED
@@ -468,14 +468,14 @@ class BiFormer(DetectorBackbone):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
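The recurring change from `param.requires_grad = False` to `param.requires_grad_(False)` in the freeze() and freeze_stages() methods across the net modules swaps the attribute assignment for the equivalent in-place Tensor method. A minimal illustration with a plain nn.Linear (constructing a birder net is not shown in this diff):

from torch import nn

layer = nn.Linear(8, 2)
for param in layer.parameters():
    param.requires_grad_(False)  # in-place setter, same effect as `param.requires_grad = False` for leaf parameters

print(any(p.requires_grad for p in layer.parameters()))  # False

# With a birder net instance, the signatures shown above expose the same pattern as, e.g.,
# net.freeze(freeze_classifier=False) or net.freeze_stages(up_to_stage=2)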
birder/net/cas_vit.py CHANGED
@@ -269,18 +269,18 @@ class CAS_ViT(DetectorBackbone):
 
     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         if freeze_classifier is False:
             for param in self.classifier.parameters():
-                param.requires_grad = True
+                param.requires_grad_(True)
 
             for param in self.dist_classifier.parameters():
-                param.requires_grad = True
+                param.requires_grad_(True)
 
         if unfreeze_features is True:
             for param in self.features.parameters():
-                param.requires_grad = True
+                param.requires_grad_(True)
 
     def transform_to_backbone(self) -> None:
         self.features = nn.Identity()
@@ -300,14 +300,14 @@ class CAS_ViT(DetectorBackbone):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/coat.py CHANGED
@@ -563,24 +563,24 @@ class CoaT(DetectorBackbone):
     def freeze_stages(self, up_to_stage: int) -> None:
         if up_to_stage >= 1:
             for param in self.patch_embed1.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
             for param in self.serial_blocks1.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
         if up_to_stage >= 2:
             for param in self.patch_embed2.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
             for param in self.serial_blocks2.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
         if up_to_stage >= 3:
             for param in self.patch_embed3.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
             for param in self.serial_blocks3.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
         if up_to_stage >= 4:
             for param in self.patch_embed4.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
             for param in self.serial_blocks4.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         features = self._features(x)
birder/net/conv2former.py CHANGED
@@ -218,14 +218,14 @@ class Conv2Former(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def masked_encoding_retention(
         self,
birder/net/convnext_v1.py CHANGED
@@ -158,14 +158,14 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def masked_encoding_retention(
         self,
@@ -195,6 +195,21 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         return self.features(x)
 
 
+registry.register_model_config(
+    "convnext_v1_atto",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [40, 80, 160, 320], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
+registry.register_model_config(
+    "convnext_v1_femto",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [48, 96, 192, 384], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
+registry.register_model_config(
+    "convnext_v1_pico",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [64, 128, 256, 512], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
 registry.register_model_config(
     "convnext_v1_nano",  # Not in the original v1, taken from v2
     ConvNeXt_v1,
@@ -220,6 +235,11 @@ registry.register_model_config(
     ConvNeXt_v1,
     config={"in_channels": [192, 384, 768, 1536], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
 )
+registry.register_model_config(
+    "convnext_v1_huge",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [352, 704, 1408, 2816], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
+)
 
 registry.register_weights(
     "convnext_v1_tiny_eu-common256px",
birder/net/convnext_v2.py CHANGED
@@ -180,14 +180,14 @@ class ConvNeXt_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def masked_encoding_retention(
         self,
birder/net/crossformer.py CHANGED
@@ -404,14 +404,14 @@ class CrossFormer(DetectorBackbone):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.patch_embed.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
birder/net/cspnet.py CHANGED
@@ -342,14 +342,14 @@ class CSPNet(DetectorBackbone):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/cswin_transformer.py CHANGED
@@ -359,14 +359,14 @@ class CSWin_Transformer(DetectorBackbone):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/darknet.py CHANGED
@@ -115,14 +115,14 @@ class Darknet(DetectorBackbone):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/davit.py CHANGED
@@ -391,14 +391,14 @@ class DaViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def masked_encoding_retention(
         self,
birder/net/deit.py CHANGED
@@ -117,14 +117,14 @@ class DeiT(BaseNet):
 
     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
         if freeze_classifier is False:
             for param in self.classifier.parameters():
-                param.requires_grad = True
+                param.requires_grad_(True)
 
             for param in self.dist_classifier.parameters():
-                param.requires_grad = True
+                param.requires_grad_(True)
 
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)
birder/net/deit3.py CHANGED
@@ -182,16 +182,16 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
-            param.requires_grad = False
+            param.requires_grad_(False)
 
-        self.pos_embedding.requires_grad = False
+        self.pos_embedding.requires_grad_(False)
 
         for idx, module in enumerate(self.encoder.children()):
             if idx >= up_to_stage:
                 break
 
             for param in module.parameters():
-                param.requires_grad = False
+                param.requires_grad_(False)
 
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)