birder 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +24 -0
  3. birder/common/training_utils.py +338 -41
  4. birder/data/collators/detection.py +11 -3
  5. birder/data/dataloader/webdataset.py +12 -2
  6. birder/data/datasets/coco.py +8 -10
  7. birder/data/transforms/detection.py +30 -13
  8. birder/inference/detection.py +108 -4
  9. birder/inference/wbf.py +226 -0
  10. birder/kernels/load_kernel.py +16 -11
  11. birder/kernels/soft_nms/soft_nms.cpp +17 -18
  12. birder/net/__init__.py +8 -0
  13. birder/net/cait.py +4 -3
  14. birder/net/convnext_v1.py +5 -0
  15. birder/net/crossformer.py +33 -30
  16. birder/net/crossvit.py +4 -3
  17. birder/net/deit.py +3 -3
  18. birder/net/deit3.py +3 -3
  19. birder/net/detection/deformable_detr.py +2 -5
  20. birder/net/detection/detr.py +2 -5
  21. birder/net/detection/efficientdet.py +67 -93
  22. birder/net/detection/fcos.py +2 -7
  23. birder/net/detection/retinanet.py +2 -7
  24. birder/net/detection/rt_detr_v1.py +2 -0
  25. birder/net/detection/yolo_anchors.py +205 -0
  26. birder/net/detection/yolo_v2.py +25 -24
  27. birder/net/detection/yolo_v3.py +39 -40
  28. birder/net/detection/yolo_v4.py +28 -26
  29. birder/net/detection/yolo_v4_tiny.py +24 -20
  30. birder/net/efficientformer_v1.py +15 -9
  31. birder/net/efficientformer_v2.py +39 -29
  32. birder/net/efficientvit_msft.py +9 -7
  33. birder/net/fasternet.py +1 -1
  34. birder/net/fastvit.py +1 -0
  35. birder/net/flexivit.py +5 -4
  36. birder/net/gc_vit.py +671 -0
  37. birder/net/hiera.py +12 -9
  38. birder/net/hornet.py +9 -7
  39. birder/net/iformer.py +8 -6
  40. birder/net/levit.py +42 -30
  41. birder/net/lit_v1.py +472 -0
  42. birder/net/lit_v1_tiny.py +357 -0
  43. birder/net/lit_v2.py +436 -0
  44. birder/net/maxvit.py +67 -55
  45. birder/net/mobilenet_v4_hybrid.py +1 -1
  46. birder/net/mobileone.py +1 -0
  47. birder/net/mvit_v2.py +13 -12
  48. birder/net/pit.py +4 -3
  49. birder/net/pvt_v1.py +4 -1
  50. birder/net/repghost.py +1 -0
  51. birder/net/repvgg.py +1 -0
  52. birder/net/repvit.py +1 -0
  53. birder/net/resnet_v1.py +1 -1
  54. birder/net/resnext.py +67 -25
  55. birder/net/rope_deit3.py +5 -3
  56. birder/net/rope_flexivit.py +7 -4
  57. birder/net/rope_vit.py +10 -5
  58. birder/net/se_resnet_v1.py +46 -0
  59. birder/net/se_resnext.py +3 -0
  60. birder/net/simple_vit.py +11 -8
  61. birder/net/swin_transformer_v1.py +71 -68
  62. birder/net/swin_transformer_v2.py +38 -31
  63. birder/net/tiny_vit.py +20 -10
  64. birder/net/transnext.py +38 -28
  65. birder/net/vit.py +5 -19
  66. birder/net/vit_parallel.py +5 -4
  67. birder/net/vit_sam.py +38 -37
  68. birder/net/vovnet_v1.py +15 -0
  69. birder/net/vovnet_v2.py +31 -1
  70. birder/ops/msda.py +108 -43
  71. birder/ops/swattention.py +124 -61
  72. birder/results/detection.py +4 -0
  73. birder/scripts/benchmark.py +110 -32
  74. birder/scripts/predict.py +8 -0
  75. birder/scripts/predict_detection.py +18 -11
  76. birder/scripts/train.py +48 -46
  77. birder/scripts/train_barlow_twins.py +44 -45
  78. birder/scripts/train_byol.py +44 -45
  79. birder/scripts/train_capi.py +50 -49
  80. birder/scripts/train_data2vec.py +45 -47
  81. birder/scripts/train_data2vec2.py +45 -47
  82. birder/scripts/train_detection.py +83 -50
  83. birder/scripts/train_dino_v1.py +60 -47
  84. birder/scripts/train_dino_v2.py +86 -52
  85. birder/scripts/train_dino_v2_dist.py +84 -50
  86. birder/scripts/train_franca.py +51 -52
  87. birder/scripts/train_i_jepa.py +45 -47
  88. birder/scripts/train_ibot.py +51 -53
  89. birder/scripts/train_kd.py +194 -76
  90. birder/scripts/train_mim.py +44 -45
  91. birder/scripts/train_mmcr.py +44 -45
  92. birder/scripts/train_rotnet.py +45 -46
  93. birder/scripts/train_simclr.py +44 -45
  94. birder/scripts/train_vicreg.py +44 -45
  95. birder/tools/auto_anchors.py +20 -1
  96. birder/tools/convert_model.py +18 -15
  97. birder/tools/det_results.py +114 -2
  98. birder/tools/pack.py +172 -103
  99. birder/tools/quantize_model.py +73 -67
  100. birder/tools/show_det_iterator.py +10 -1
  101. birder/version.py +1 -1
  102. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
  103. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
  104. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
  105. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
  106. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
  107. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/net/__init__.py CHANGED
@@ -31,6 +31,7 @@ from birder.net.fasternet import FasterNet
 from birder.net.fastvit import FastViT
 from birder.net.flexivit import FlexiViT
 from birder.net.focalnet import FocalNet
+from birder.net.gc_vit import GC_ViT
 from birder.net.ghostnet_v1 import GhostNet_v1
 from birder.net.ghostnet_v2 import GhostNet_v2
 from birder.net.groupmixformer import GroupMixFormer
@@ -46,6 +47,9 @@ from birder.net.inception_resnet_v2 import Inception_ResNet_v2
 from birder.net.inception_v3 import Inception_v3
 from birder.net.inception_v4 import Inception_v4
 from birder.net.levit import LeViT
+from birder.net.lit_v1 import LIT_v1
+from birder.net.lit_v1_tiny import LIT_v1_Tiny
+from birder.net.lit_v2 import LIT_v2
 from birder.net.maxvit import MaxViT
 from birder.net.metaformer import MetaFormer
 from birder.net.mnasnet import MNASNet
@@ -143,6 +147,7 @@ __all__ = [
     "FastViT",
     "FlexiViT",
     "FocalNet",
+    "GC_ViT",
     "GhostNet_v1",
     "GhostNet_v2",
     "GroupMixFormer",
@@ -158,6 +163,9 @@ __all__ = [
     "Inception_v3",
     "Inception_v4",
     "LeViT",
+    "LIT_v1",
+    "LIT_v1_Tiny",
+    "LIT_v2",
     "MaxViT",
     "MetaFormer",
     "MNASNet",
birder/net/cait.py CHANGED
@@ -268,14 +268,15 @@ class CaiT(BaseNet):
         super().adjust_size(new_size)

         # Add back class tokens
-        self.pos_embed = nn.Parameter(
-            adjust_position_embedding(
+        with torch.no_grad():
+            pos_embed = adjust_position_embedding(
                 self.pos_embed,
                 (old_size[0] // self.patch_size[0], old_size[1] // self.patch_size[1]),
                 (new_size[0] // self.patch_size[0], new_size[1] // self.patch_size[1]),
                 0,
             )
-        )
+
+        self.pos_embed = nn.Parameter(pos_embed)


 registry.register_model_config(
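
The recurring change in this release's classification nets (repeated below for crossvit.py, deit.py and deit3.py) wraps the position-embedding interpolation in torch.no_grad() and only afterwards re-wraps the result as an nn.Parameter, so the resize never becomes part of an autograd graph. A minimal sketch of the pattern, using a hypothetical interpolate_pos_embed helper rather than birder's actual adjust_position_embedding:

import torch
import torch.nn.functional as F
from torch import nn


def interpolate_pos_embed(pos_embed: torch.Tensor, old_hw: tuple[int, int], new_hw: tuple[int, int]) -> torch.Tensor:
    # (1, H*W, C) -> (1, C, H, W) -> bicubic resize -> (1, H'*W', C)
    (_, _, dim) = pos_embed.shape
    grid = pos_embed.reshape(1, old_hw[0], old_hw[1], dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_hw, mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_hw[0] * new_hw[1], dim)


class ToyViT(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, 14 * 14, 192))

    def adjust_size(self, new_hw: tuple[int, int]) -> None:
        with torch.no_grad():  # keep the interpolation out of the autograd graph
            pos_embed = interpolate_pos_embed(self.pos_embed, (14, 14), new_hw)

        self.pos_embed = nn.Parameter(pos_embed)  # re-register as a fresh leaf Parameter

Re-assigning the attribute creates a fresh leaf Parameter, so an optimizer constructed afterwards sees an ordinary trainable tensor with no stale gradient history.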
birder/net/convnext_v1.py CHANGED
@@ -195,6 +195,11 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         return self.features(x)


+registry.register_model_config(
+    "convnext_v1_nano",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [80, 160, 320, 640], "num_layers": [2, 2, 8, 2], "drop_path_rate": 0.1},
+)
 registry.register_model_config(
     "convnext_v1_tiny",
     ConvNeXt_v1,
birder/net/crossformer.py CHANGED
@@ -98,15 +98,17 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def define_bias_table(self) -> None:
-        position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
-        position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
+        device = next(self.pos.parameters()).device
+        position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0], device=device)
+        position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1], device=device)
         biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w], indexing="ij"))  # 2, 2Wh-1, 2W2-1
         biases = biases.flatten(1).transpose(0, 1).float()
         self.biases = nn.Buffer(biases)

     def define_relative_position_index(self) -> None:
-        coords_h = torch.arange(self.group_size[0])
-        coords_w = torch.arange(self.group_size[1])
+        device = self.biases.device
+        coords_h = torch.arange(self.group_size[0], device=device)
+        coords_w = torch.arange(self.group_size[1], device=device)
         coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
@@ -430,32 +432,33 @@ class CrossFormer(DetectorBackbone):

         new_patch_resolution = (new_size[0] // self.patch_sizes[0], new_size[1] // self.patch_sizes[0])
         input_resolution = new_patch_resolution
-        for mod in self.body.modules():
-            if isinstance(mod, CrossFormerStage):
-                for m in mod.modules():
-                    if isinstance(m, PatchMerging):
-                        m.input_resolution = input_resolution
-                        input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
-                    elif isinstance(m, CrossFormerBlock):
-                        m.input_resolution = input_resolution
-
-                mod.resolution = input_resolution
-
-        new_group_size = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
-        for m in self.body.modules():
-            if isinstance(m, CrossFormerBlock):
-                m.group_size = new_group_size
-                if m.input_resolution[0] <= m.group_size[0]:
-                    m.use_lda = False
-                    m.group_size = (m.input_resolution[0], m.group_size[1])
-                if m.input_resolution[1] <= m.group_size[1]:
-                    m.use_lda = False
-                    m.group_size = (m.group_size[0], m.input_resolution[1])
-
-            elif isinstance(m, Attention):
-                m.group_size = new_group_size
-                m.define_bias_table()
-                m.define_relative_position_index()
+        with torch.no_grad():
+            for mod in self.body.modules():
+                if isinstance(mod, CrossFormerStage):
+                    for m in mod.modules():
+                        if isinstance(m, PatchMerging):
+                            m.input_resolution = input_resolution
+                            input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
+                        elif isinstance(m, CrossFormerBlock):
+                            m.input_resolution = input_resolution
+
+                    mod.resolution = input_resolution
+
+            new_group_size = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
+            for m in self.body.modules():
+                if isinstance(m, CrossFormerBlock):
+                    m.group_size = new_group_size
+                    if m.input_resolution[0] <= m.group_size[0]:
+                        m.use_lda = False
+                        m.group_size = (m.input_resolution[0], m.group_size[1])
+                    if m.input_resolution[1] <= m.group_size[1]:
+                        m.use_lda = False
+                        m.group_size = (m.group_size[0], m.input_resolution[1])
+
+                elif isinstance(m, Attention):
+                    m.group_size = new_group_size
+                    m.define_bias_table()
+                    m.define_relative_position_index()


 registry.register_model_config(
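
The define_bias_table() and define_relative_position_index() changes above pass an explicit device= when the index tensors are rebuilt, so a model that has already been moved to the GPU does not end up with freshly allocated CPU buffers after resizing. An illustrative sketch of the idiom (not birder code; it assumes a PyTorch version that provides nn.Buffer, which this file already uses):

import torch
from torch import nn


class WithIndex(nn.Module):
    def __init__(self, n: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n))
        self.index = nn.Buffer(torch.arange(n))

    def rebuild_index(self, n: int) -> None:
        # Allocate the new buffer on whatever device the module currently lives on,
        # otherwise later ops would mix CPU and CUDA tensors after .to("cuda")
        device = self.weight.device
        self.index = nn.Buffer(torch.arange(n, device=device))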
birder/net/crossvit.py CHANGED
@@ -359,9 +359,10 @@ class CrossViT(BaseNet):
             old_w = old_size[1] // self.patch_size[i]
             h = new_size[0] // self.patch_size[i]
             w = new_size[1] // self.patch_size[i]
-            self.pos_embed[i] = nn.Parameter(
-                adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
-            )
+            with torch.no_grad():
+                pos_embed = adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
+
+            self.pos_embed[i] = nn.Parameter(pos_embed)


 registry.register_model_config(
birder/net/deit.py CHANGED
@@ -187,14 +187,14 @@ class DeiT(BaseNet):
             num_prefix_tokens = 2

         # Add back class tokens
-        self.pos_embedding = nn.Parameter(
-            adjust_position_embedding(
+        with torch.no_grad():
+            pos_embedding = adjust_position_embedding(
                 self.pos_embedding,
                 (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
                 (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
                 num_prefix_tokens,
             )
-        )
+        self.pos_embedding = nn.Parameter(pos_embedding)


 registry.register_model_config(
birder/net/deit3.py CHANGED
@@ -355,14 +355,14 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
             num_prefix_tokens = 0

         # Add back class tokens
-        self.pos_embedding = nn.Parameter(
-            adjust_position_embedding(
+        with torch.no_grad():
+            pos_embedding = adjust_position_embedding(
                 self.pos_embedding,
                 (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
                 (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
                 num_prefix_tokens,
             )
-        )
+        self.pos_embedding = nn.Parameter(pos_embedding)


 registry.register_model_config(
birder/net/detection/deformable_detr.py CHANGED
@@ -757,11 +757,8 @@ class Deformable_DETR(DetectionBaseNet):
         for s, l, b in zip(scores, labels, boxes):
             # Non-maximum suppression
             if self.soft_nms is not None:
-                # Actually much faster on CPU
-                device = b.device
-                (soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
-                keep = keep.to(device)
-                s[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
+                s[keep] = soft_scores

             b = b[keep]
             s = s[keep]
birder/net/detection/detr.py CHANGED
@@ -465,11 +465,8 @@ class DETR(DetectionBaseNet):
         for s, l, b in zip(scores, labels, boxes):
             # Non-maximum suppression
             if self.soft_nms is not None:
-                # Actually much faster on CPU
-                device = b.device
-                (soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
-                keep = keep.to(device)
-                s[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
+                s[keep] = soft_scores

             b = b[keep]
             s = s[keep]
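
The same simplification appears again in efficientdet.py, fcos.py and retinanet.py below: the detection heads now call the soft-NMS kernel directly on whatever device the boxes live on instead of round-tripping through the CPU. For readers unfamiliar with the (soft_scores, keep) interface, here is an illustrative pure-PyTorch Gaussian soft-NMS; it ignores class labels for brevity and is not the compiled kernel in birder/kernels/soft_nms:

import torch
from torchvision.ops import box_iou


def soft_nms(boxes: torch.Tensor, scores: torch.Tensor, sigma: float = 0.5, score_threshold: float = 0.001):
    # Returns decayed scores for the kept boxes and their indices (in descending score order)
    scores = scores.clone()
    idxs = torch.arange(scores.numel(), device=scores.device)
    order = []
    while idxs.numel() > 0:
        top = int(torch.argmax(scores[idxs]))
        keep_idx = idxs[top]
        order.append(keep_idx)
        idxs = torch.cat([idxs[:top], idxs[top + 1 :]])
        if idxs.numel() == 0:
            break
        iou = box_iou(boxes[keep_idx].unsqueeze(0), boxes[idxs]).squeeze(0)
        scores[idxs] = scores[idxs] * torch.exp(-(iou**2) / sigma)  # decay overlapping scores instead of dropping them
        idxs = idxs[scores[idxs] > score_threshold]
    keep = torch.stack(order) if len(order) > 0 else idxs
    return scores[keep], keep

The real kernel also receives the class labels, which this toy version omits, and the decayed scores are written back with s[keep] = soft_scores exactly as in the hunks above.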
birder/net/detection/efficientdet.py CHANGED
@@ -83,32 +83,25 @@ class Interpolate2d(nn.Module):

     def __init__(
         self,
-        size: Optional[int | tuple[int, int]] = None,
-        scale_factor: Optional[float | tuple[float, float]] = None,
         mode: str = "nearest",
         align_corners: Optional[bool] = False,
     ) -> None:
         super().__init__()
-        self.size = size
-        self.scale_factor = scale_factor
         self.mode = mode
         self.align_corners = align_corners
         if mode == "nearest":
             self.align_corners = None

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.interpolate(
-            x, self.size, self.scale_factor, self.mode, self.align_corners, recompute_scale_factor=False
-        )
+    def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
+        size_list = [size[0], size[1]]
+        return F.interpolate(x, size_list, None, self.mode, self.align_corners, recompute_scale_factor=False)


-class ResampleFeatureMap(nn.Sequential):
+class ResampleFeatureMap(nn.Module):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        input_size: tuple[int, int],
-        output_size: tuple[int, int],
         downsample: Literal["max", "bilinear"],
         upsample: Literal["nearest", "bilinear"],
         norm_layer: Optional[Callable[..., nn.Module]],
@@ -116,46 +109,63 @@ class ResampleFeatureMap(nn.Sequential):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.input_size = input_size
-        self.output_size = output_size
+        self.downsample_mode = downsample

         if in_channels != out_channels:
             # padding = ((stride - 1) + (kernel_size - 1)) // 2
-            self.add_module(
-                "conv",
-                Conv2dNormActivation(
-                    in_channels,
-                    out_channels,
-                    kernel_size=(1, 1),
-                    stride=(1, 1),
-                    padding=(0, 0),
-                    norm_layer=norm_layer,
-                    bias=False,
-                    activation_layer=None,
-                ),
+            self.conv = Conv2dNormActivation(
+                in_channels,
+                out_channels,
+                kernel_size=(1, 1),
+                stride=(1, 1),
+                padding=(0, 0),
+                norm_layer=norm_layer,
+                bias=False,
+                activation_layer=None,
             )
+        else:
+            self.conv = None

-        if input_size[0] > output_size[0] and input_size[1] > output_size[1]:
-            if downsample == "max":
-                stride_size_h = int((input_size[0] - 1) // output_size[0] + 1)
-                stride_size_w = int((input_size[1] - 1) // output_size[1] + 1)
+        self.downsample = None
+        if downsample != "max":
+            self.downsample = Interpolate2d(mode=downsample)
+
+        self.upsample = Interpolate2d(mode=upsample)
+
+    def forward(self, x: torch.Tensor, target_size: tuple[int, int]) -> torch.Tensor:
+        if self.conv is not None:
+            x = self.conv(x)
+
+        (in_h, in_w) = x.shape[-2:]
+        (target_h, target_w) = target_size
+        if in_h == target_h and in_w == target_w:
+            return x
+
+        downsample_needed = in_h > target_h or in_w > target_w
+        upsample_needed = in_h < target_h or in_w < target_w
+
+        if downsample_needed is True and upsample_needed is False:
+            if self.downsample_mode == "max":
+                stride_size_h = int((in_h - 1) // target_h + 1)
+                stride_size_w = int((in_w - 1) // target_w + 1)
                 kernel_size = (stride_size_h + 1, stride_size_w + 1)
                 stride = (stride_size_h, stride_size_w)
                 padding = (
                     ((stride[0] - 1) + (kernel_size[0] - 1)) // 2,
                     ((stride[1] - 1) + (kernel_size[1] - 1)) // 2,
                 )
+                return F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)

-                down_inst = nn.MaxPool2d(kernel_size, stride=stride, padding=padding)
+            if self.downsample is not None:
+                return self.downsample(x, size=target_size)

-            else:
-                down_inst = Interpolate2d(size=output_size, mode=downsample)
+        if upsample_needed is True and downsample_needed is False:
+            return self.upsample(x, size=target_size)

-            self.add_module("downsample", down_inst)
+        if self.downsample is not None and self.downsample_mode != "max":
+            return self.downsample(x, size=target_size)

-        else:
-            if input_size[0] < output_size[0] or input_size[1] < output_size[1]:
-                self.add_module("upsample", Interpolate2d(size=output_size, mode=upsample))
+        return self.upsample(x, size=target_size)


 class FpnCombine(nn.Module):
@@ -164,8 +174,6 @@ class FpnCombine(nn.Module):
         in_channels: list[int],
         fpn_channels: int,
         inputs_offsets: list[int],
-        input_size: list[tuple[int, int]],
-        output_size: tuple[int, int],
         downsample: Literal["max", "bilinear"],
         upsample: Literal["nearest", "bilinear"],
         norm_layer: Optional[Callable[..., nn.Module]],
@@ -173,14 +181,14 @@ class FpnCombine(nn.Module):
     ):
         super().__init__()
         self.weight_method = weight_method
+        self.inputs_offsets = inputs_offsets
+        self.target_offset = inputs_offsets[0]

         self.resample = nn.ModuleDict()
         for offset in inputs_offsets:
             self.resample[str(offset)] = ResampleFeatureMap(
                 in_channels[offset],
                 fpn_channels,
-                input_size=input_size[offset],
-                output_size=output_size,
                 downsample=downsample,
                 upsample=upsample,
                 norm_layer=norm_layer,
@@ -193,10 +201,12 @@ class FpnCombine(nn.Module):

     def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
         dtype = x[0].dtype
+        target = x[self.target_offset]
+        target_size = (int(target.shape[-2]), int(target.shape[-1]))
         nodes = []
         for offset, resample in self.resample.items():
             input_node = x[int(offset)]
-            input_node = resample(input_node)
+            input_node = resample(input_node, target_size=target_size)
             nodes.append(input_node)

         if self.weight_method == "attn":
@@ -231,8 +241,6 @@ class BiFpnLayer(nn.Module):
     def __init__(
         self,
         in_channels: list[int],
-        input_size: list[tuple[int, int]],
-        feat_sizes: list[tuple[int, int]],
         fpn_config: list[dict[str, Any]],
         fpn_channels: int,
         num_levels: int,
@@ -248,8 +256,6 @@ class BiFpnLayer(nn.Module):
                 in_channels,
                 fpn_channels,
                 inputs_offsets=fnode_cfg["inputs_offsets"],
-                input_size=input_size,
-                output_size=feat_sizes[fnode_cfg["feat_level"]],
                 downsample=downsample,
                 upsample=upsample,
                 norm_layer=norm_layer,
@@ -290,9 +296,6 @@ class BiFpn(nn.Module):
 class BiFpn(nn.Module):
     def __init__(
         self,
-        image_size: tuple[int, int],
-        min_level: int,
-        max_level: int,
         num_levels: int,
         backbone_channels: list[int],
         fpn_channels: int,
@@ -300,45 +303,29 @@ class BiFpn(nn.Module):
         bifpn_config: list[dict[str, Any]],
     ):
         super().__init__()
-        feat_size = image_size
-        feat_sizes = [feat_size]
-        for _ in range(1, max_level + 1):
-            feat_size = ((feat_size[0] - 1) // 2 + 1, (feat_size[1] - 1) // 2 + 1)
-            feat_sizes.append(feat_size)
-
-        input_size = feat_sizes.copy()
-        input_size = input_size[-num_levels:]
-        prev_feat_size = feat_sizes[min_level]
-        self.resample = nn.ModuleDict()
-        for level in range(num_levels):
-            feat_size = feat_sizes[level + min_level]
-            if level < len(backbone_channels):
-                in_channels = backbone_channels[level]
-                input_size[level] = feat_size
-            else:
-                self.resample[str(level)] = ResampleFeatureMap(
+        self.resample = nn.ModuleList()
+        num_backbone_levels = len(backbone_channels)
+        extra_levels = max(0, num_levels - num_backbone_levels)
+        in_channels = backbone_channels[-1]
+        for _ in range(extra_levels):
+            self.resample.append(
+                ResampleFeatureMap(
                     in_channels=in_channels,
                     out_channels=fpn_channels,
-                    input_size=prev_feat_size,
-                    output_size=feat_size,
                     downsample="max",
                     upsample="nearest",
                     norm_layer=nn.BatchNorm2d,
                 )
-                in_channels = fpn_channels
-                backbone_channels.append(in_channels)
-
-            prev_feat_size = feat_size
+            )
+            in_channels = fpn_channels
+            backbone_channels.append(in_channels)

         self.cells = nn.ModuleList()
         fpn_combine_channels = backbone_channels
         for _ in range(fpn_cell_repeats):
             fpn_combine_channels = fpn_combine_channels + [fpn_channels for _ in bifpn_config]
-            input_size = input_size + [feat_sizes[fc["feat_level"]] for fc in bifpn_config]
             fpn_layer = BiFpnLayer(
                 in_channels=fpn_combine_channels,
-                input_size=input_size,
-                feat_sizes=feat_sizes,
                 fpn_config=bifpn_config,
                 fpn_channels=fpn_channels,
                 num_levels=num_levels,
@@ -348,11 +335,12 @@ class BiFpn(nn.Module):
             )
             self.cells.append(fpn_layer)
             fpn_combine_channels = fpn_combine_channels[-num_levels::]
-            input_size = input_size[-num_levels::]

     def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
-        for resample in self.resample.values():
-            x.append(resample(x[-1]))
+        for resample in self.resample:
+            input_node = x[-1]
+            target_size = ((input_node.shape[-2] - 1) // 2 + 1, (input_node.shape[-1] - 1) // 2 + 1)
+            x.append(resample(input_node, target_size=target_size))

         for cell in self.cells:
             x = cell(x)
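
The thrust of the BiFpn changes above and below is that feature-map sizes are no longer precomputed from a fixed image_size; each extra pyramid level and each resample step derives its target size from the tensors at hand, which is what lets the adjust_size() override that raised "Model resizing not supported" be removed further down. A rough sketch of that idea, with illustrative helper names rather than birder's API:

import torch
import torch.nn.functional as F


def next_level_size(h: int, w: int) -> tuple[int, int]:
    # Matches the stride-2 "ceil" convention used in the new forward(): (s - 1) // 2 + 1
    return ((h - 1) // 2 + 1, (w - 1) // 2 + 1)


def resample_to(x: torch.Tensor, target_hw: tuple[int, int]) -> torch.Tensor:
    (in_h, in_w) = x.shape[-2:]
    (t_h, t_w) = target_hw
    if (in_h, in_w) == (t_h, t_w):
        return x
    if in_h > t_h or in_w > t_w:
        # Max-pool down using the same stride/kernel/padding arithmetic as the diff
        stride = ((in_h - 1) // t_h + 1, (in_w - 1) // t_w + 1)
        kernel = (stride[0] + 1, stride[1] + 1)
        padding = (((stride[0] - 1) + (kernel[0] - 1)) // 2, ((stride[1] - 1) + (kernel[1] - 1)) // 2)
        return F.max_pool2d(x, kernel_size=kernel, stride=stride, padding=padding)
    return F.interpolate(x, size=(t_h, t_w), mode="nearest")


p5 = torch.randn(1, 64, 20, 20)
p6 = resample_to(p5, next_level_size(20, 20))  # (1, 64, 10, 10)

Here p6 comes out at 10x10 from a 20x20 p5, following the same (s - 1) // 2 + 1 convention the new forward() uses when appending extra pyramid levels.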
@@ -572,9 +560,6 @@ class EfficientDet(DetectionBaseNet):
             self.backbone.return_stages = self.backbone.return_stages[-3:]

         self.bifpn = BiFpn(
-            image_size=self.size,
-            min_level=min_level,
-            max_level=max_level,
             num_levels=num_levels,
             backbone_channels=self.backbone.return_channels,
             fpn_channels=fpn_channels,
614
599
  num_anchors=self.anchor_generator.num_anchors_per_location()[0],
615
600
  )
616
601
 
617
- def adjust_size(self, new_size: tuple[int, int]) -> None:
618
- if new_size == self.size:
619
- return
620
-
621
- raise RuntimeError("Model resizing not supported")
622
-
623
602
  def freeze(self, freeze_classifier: bool = True) -> None:
624
603
  for param in self.parameters():
625
604
  param.requires_grad = False
@@ -706,13 +685,8 @@ class EfficientDet(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-                # Actually much faster on CPU
-                device = image_boxes.device
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)

birder/net/detection/fcos.py CHANGED
@@ -455,13 +455,8 @@ class FCOS(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-                # Actually much faster on CPU
-                device = image_boxes.device
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)

birder/net/detection/retinanet.py CHANGED
@@ -417,13 +417,8 @@ class RetinaNet(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-                # Actually much faster on CPU
-                device = image_boxes.device
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)

birder/net/detection/rt_detr_v1.py CHANGED
@@ -1070,6 +1070,7 @@ class RT_DETR_v1(DetectionBaseNet):
             W = feat.shape[3]
             spatial_shapes.append([H, W])
             level_start_index.append(H * W + level_start_index[-1])
+
         level_start_index.pop()

         detections: list[dict[str, torch.Tensor]] = []
@@ -1086,6 +1087,7 @@ class RT_DETR_v1(DetectionBaseNet):

         return (detections, losses)

+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         if self.reparameterized is True:
             return