birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/net/deit3.py CHANGED
@@ -185,7 +185,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
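Note (illustration, not part of the diff): zip(strict=True) turns a silent truncation into an explicit error when the two sequences get out of sync; a minimal sketch with made-up values:

stages = ["stage1", "stage2", "stage3"]
features = ["f1", "f2"]
list(zip(stages, features))                # [("stage1", "f1"), ("stage2", "f2")] - stage3 silently dropped
list(zip(stages, features, strict=True))   # raises ValueError (Python 3.10+)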
birder/net/detection/__init__.py CHANGED
@@ -3,6 +3,7 @@ from birder.net.detection.detr import DETR
 from birder.net.detection.efficientdet import EfficientDet
 from birder.net.detection.faster_rcnn import Faster_RCNN
 from birder.net.detection.fcos import FCOS
+from birder.net.detection.lw_detr import LW_DETR
 from birder.net.detection.plain_detr import Plain_DETR
 from birder.net.detection.retinanet import RetinaNet
 from birder.net.detection.rt_detr_v1 import RT_DETR_v1
@@ -21,6 +22,7 @@ __all__ = [
     "EfficientDet",
     "Faster_RCNN",
     "FCOS",
+    "LW_DETR",
     "Plain_DETR",
     "RetinaNet",
     "RT_DETR_v1",
birder/net/detection/base.py CHANGED
@@ -44,6 +44,7 @@ class DetectionBaseNet(nn.Module):
     block_group_regex: Optional[str]
     auto_register = False
     scriptable = True
+    exportable = True
     task = str(Task.OBJECT_DETECTION)

     def __init_subclass__(cls) -> None:
@@ -134,40 +135,62 @@ class DetectionBaseNet(nn.Module):
                     f" Found invalid box {degenerate_bb} for target at index {target_idx}.",
                 )

-    # pylint: disable=protected-access
-    def _to_img_list(self, x: torch.Tensor, image_sizes: Optional[list[list[int]]] = None) -> "ImageList":
+    def _to_img_list(self, x: torch.Tensor, image_sizes: Optional[list[tuple[int, int]]] = None) -> "ImageList":
+        B = x.size(0)
         if image_sizes is None:
-            image_sizes = [img.shape[-2:] for img in x]
-
-        image_sizes_list: list[tuple[int, int]] = []
-        for image_size in image_sizes:
-            torch._assert(
-                len(image_size) == 2,
-                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
-            )
-            image_sizes_list.append((image_size[0], image_size[1]))
+            H = x.size(2)
+            W = x.size(3)
+            h_tensor = torch.full((B,), H, dtype=torch.int64, device=x.device)
+            w_tensor = torch.full((B,), W, dtype=torch.int64, device=x.device)
+            image_sizes_tensor = torch.stack([h_tensor, w_tensor], dim=1)
+        else:
+            image_sizes_tensor = torch.tensor(image_sizes, dtype=torch.int64, device=x.device)

-        return ImageList(x, image_sizes_list)
+        return ImageList(x, image_sizes_tensor)

     def forward(
         self,
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         # TypedDict not supported for TorchScript - avoid returning DetectorResultType
         raise NotImplementedError


 class ImageList:
-    def __init__(self, tensors: torch.Tensor, image_sizes: list[tuple[int, int]]) -> None:
+    def __init__(self, tensors: torch.Tensor, image_sizes: torch.Tensor) -> None:
         self.tensors = tensors
-        self.image_sizes = image_sizes
+        self.image_sizes = image_sizes  # Shape: (B, 2) with [H, W] format

     def to(self, device: torch.device) -> "ImageList":
         cast_tensor = self.tensors.to(device)
-        return ImageList(cast_tensor, self.image_sizes)
+        cast_sizes = self.image_sizes.to(device)
+        return ImageList(cast_tensor, cast_sizes)
+
+
+def clip_boxes_to_image(boxes: torch.Tensor, image_size: torch.Tensor) -> torch.Tensor:
+    """
+    Clip boxes to image boundaries
+
+    Parameters
+    ----------
+    boxes
+        Boxes in (x1, y1, x2, y2) format, shape (..., 4)
+    image_size
+        Tensor of [height, width]
+
+    Returns
+    -------
+    Clipped boxes
+    """
+
+    boxes_x = boxes[..., 0::2].clamp(min=0, max=image_size[1])
+    boxes_y = boxes[..., 1::2].clamp(min=0, max=image_size[0])
+    clipped_boxes = torch.stack([boxes_x[..., 0], boxes_y[..., 0], boxes_x[..., 1], boxes_y[..., 1]], dim=-1)
+
+    return clipped_boxes


 ###############################################################################
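Note (illustration, not part of the diff): a minimal usage sketch of the new module-level clip_boxes_to_image helper, with made-up values:

import torch

boxes = torch.tensor([[-5.0, 10.0, 700.0, 300.0]])  # (x1, y1, x2, y2)
image_size = torch.tensor([480, 640])                # [H, W]
clip_boxes_to_image(boxes, image_size)               # -> tensor([[0., 10., 640., 300.]])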
@@ -325,7 +348,7 @@ class SimpleFeaturePyramidNetwork(nn.Module):


 # pylint: disable=protected-access,too-many-locals
-@torch.jit._script_if_tracing  # type: ignore
+@torch.jit._script_if_tracing  # type: ignore[untyped-decorator]
 def encode_boxes(reference_boxes: torch.Tensor, proposals: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
     """
     Encode a set of proposals with respect to some reference boxes
@@ -609,7 +632,7 @@ class Matcher(nn.Module):

         if match_quality_matrix.numel() == 0:
             # Empty targets or proposals not supported during training
-            if match_quality_matrix.shape[0] == 0:
+            if match_quality_matrix.size(0) == 0:
                 raise ValueError("No ground-truth boxes available for one of the images during training")

             raise ValueError("No proposal boxes available for one of the images during training")
birder/net/detection/deformable_detr.py CHANGED
@@ -56,7 +56,7 @@ class HungarianMatcher(nn.Module):
    @torch.jit.unused  # type: ignore[untyped-decorator]
    def forward(
        self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
-    ) -> list[torch.Tensor]:
+    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        with torch.no_grad():
            B, num_queries = class_logits.shape[:2]

@@ -135,7 +135,7 @@ class MultiScaleDeformableAttention(nn.Module):
         self.reset_parameters()

     def reset_parameters(self) -> None:
-        nn.init.constant_(self.sampling_offsets.weight, 0.0)
+        nn.init.zeros_(self.sampling_offsets.weight)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = (
@@ -149,12 +149,12 @@ class MultiScaleDeformableAttention(nn.Module):
         with torch.no_grad():
             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))

-        nn.init.constant_(self.attention_weights.weight, 0.0)
-        nn.init.constant_(self.attention_weights.bias, 0.0)
+        nn.init.zeros_(self.attention_weights.weight)
+        nn.init.zeros_(self.attention_weights.bias)
         nn.init.xavier_uniform_(self.value_proj.weight)
-        nn.init.constant_(self.value_proj.bias, 0.0)
+        nn.init.zeros_(self.value_proj.bias)
         nn.init.xavier_uniform_(self.output_proj.weight)
-        nn.init.constant_(self.output_proj.bias, 0.0)
+        nn.init.zeros_(self.output_proj.bias)

     def forward(
         self,
@@ -164,10 +164,11 @@ class MultiScaleDeformableAttention(nn.Module):
         input_spatial_shapes: torch.Tensor,
         input_level_start_index: torch.Tensor,
         input_padding_mask: Optional[torch.Tensor] = None,
+        src_shapes: Optional[list[list[int]]] = None,
     ) -> torch.Tensor:
         N, num_queries, _ = query.size()
         N, sequence_length, _ = input_flatten.size()
-        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length
+        # assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length

         value = self.value_proj(input_flatten)
         if input_padding_mask is not None:
@@ -208,6 +209,7 @@ class MultiScaleDeformableAttention(nn.Module):
             sampling_locations,
             attention_weights,
             self.im2col_step,
+            src_shapes,
         )

         output = self.output_proj(output)
@@ -235,8 +237,9 @@ class DeformableTransformerEncoderLayer(nn.Module):
         spatial_shapes: torch.Tensor,
         level_start_index: torch.Tensor,
         mask: Optional[torch.Tensor],
+        src_shapes: Optional[list[list[int]]] = None,
     ) -> torch.Tensor:
-        src2 = self.self_attn(src + pos, reference_points, src, spatial_shapes, level_start_index, mask)
+        src2 = self.self_attn(src + pos, reference_points, src, spatial_shapes, level_start_index, mask, src_shapes)
         src = src + self.dropout(src2)
         src = self.norm1(src)

@@ -277,13 +280,13 @@ class DeformableTransformerDecoderLayer(nn.Module):
         level_start_index: torch.Tensor,
         src_padding_mask: Optional[torch.Tensor],
         self_attn_mask: Optional[torch.Tensor] = None,
+        src_shapes: Optional[list[list[int]]] = None,
     ) -> torch.Tensor:
         # Self attention
-        q = tgt + query_pos
-        k = tgt + query_pos
+        q_k = tgt + query_pos

         tgt2, _ = self.self_attn(
-            q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
+            q_k.transpose(0, 1), q_k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
         )
         tgt2 = tgt2.transpose(0, 1)
         tgt = tgt + self.dropout(tgt2)
@@ -291,7 +294,7 @@ class DeformableTransformerDecoderLayer(nn.Module):

         # Cross attention
         tgt2 = self.cross_attn(
-            tgt + query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask
+            tgt + query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask, src_shapes
         )
         tgt = tgt + self.dropout(tgt2)
         tgt = self.norm2(tgt)
@@ -311,17 +314,15 @@ class DeformableTransformerEncoder(nn.Module):

     @staticmethod
     def get_reference_points(
-        spatial_shapes: torch.Tensor, valid_ratios: torch.Tensor, device: torch.device
+        src_shapes: list[list[int]], valid_ratios: torch.Tensor, device: torch.device
     ) -> torch.Tensor:
         reference_points_list = []
-        for lvl, spatial_shape in enumerate(spatial_shapes):
-            H = spatial_shape[0]
-            W = spatial_shape[1]
-            ref_y, ref_x = torch.meshgrid(
-                torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
-                torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
-                indexing="ij",
-            )
+        for lvl, (H, W) in enumerate(src_shapes):
+            # Use arange instead of linspace - works with symbolic sizes
+            # linspace(0.5, H-0.5, H) is equivalent to arange(H) + 0.5
+            ref_y = (torch.arange(H, dtype=torch.float32, device=device) + 0.5).view(-1, 1).expand(-1, W)
+            ref_x = (torch.arange(W, dtype=torch.float32, device=device) + 0.5).view(1, -1).expand(H, -1)
+
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
             ref = torch.stack((ref_x, ref_y), dim=-1)
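Note (illustration, not part of the diff): for an integer H the two formulations build the same grid; arange avoids a data-dependent step size, which is what the in-diff comment means by working with symbolic sizes:

import torch

H = 7
a = torch.arange(H, dtype=torch.float32) + 0.5
b = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32)
assert torch.allclose(a, b)  # both are [0.5, 1.5, ..., 6.5]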
@@ -336,15 +337,16 @@ class DeformableTransformerEncoder(nn.Module):
         self,
         src: torch.Tensor,
         spatial_shapes: torch.Tensor,
+        src_shapes: list[list[int]],
         level_start_index: torch.Tensor,
         pos: torch.Tensor,
         valid_ratios: torch.Tensor,
         mask: torch.Tensor,
     ) -> torch.Tensor:
         out = src
-        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
+        reference_points = self.get_reference_points(src_shapes, valid_ratios, device=src.device)
         for layer in self.layers:
-            out = layer(out, pos, reference_points, spatial_shapes, level_start_index, mask)
+            out = layer(out, pos, reference_points, spatial_shapes, level_start_index, mask, src_shapes)

         return out

@@ -369,6 +371,7 @@ class DeformableTransformerDecoder(nn.Module):
         query_pos: torch.Tensor,
         src_valid_ratios: torch.Tensor,
         src_padding_mask: torch.Tensor,
+        src_shapes: Optional[list[list[int]]] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         output = tgt

@@ -391,6 +394,7 @@ class DeformableTransformerDecoder(nn.Module):
                 src_spatial_shapes,
                 src_level_start_index,
                 src_padding_mask,
+                src_shapes=src_shapes,
             )

             if self.box_refine is True:
@@ -482,10 +486,11 @@ class DeformableTransformer(nn.Module):
         src_list = []
         lvl_pos_embed_list = []
         mask_list = []
-        spatial_shape_list: list[list[int]] = []  # list[tuple[int, int]] not supported on TorchScript
+        src_shapes: list[list[int]] = []  # list[tuple[int, int]] not supported on TorchScript
+
         for lvl, (src, pos_embed, mask) in enumerate(zip(srcs, pos_embeds, masks)):
-            _, _, H, W = src.size()
-            spatial_shape_list.append([H, W])
+            H, W = src.shape[-2], src.shape[-1]
+            src_shapes.append([H, W])
             src = src.flatten(2).transpose(1, 2)
             pos_embed = pos_embed.flatten(2).transpose(1, 2)
             mask = mask.flatten(1)
@@ -497,13 +502,19 @@ class DeformableTransformer(nn.Module):
         src_flatten = torch.concat(src_list, dim=1)
         mask_flatten = torch.concat(mask_list, dim=1)
         lvl_pos_embed_flatten = torch.concat(lvl_pos_embed_list, dim=1)
-        spatial_shapes = torch.as_tensor(spatial_shape_list, dtype=torch.long, device=src_flatten.device)
+        spatial_shapes = torch.tensor(src_shapes, dtype=torch.long, device=src_flatten.device)
         level_start_index = torch.concat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]), dim=0)
         valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], dim=1)

         # Encoder
         memory = self.encoder(
-            src_flatten, spatial_shapes, level_start_index, lvl_pos_embed_flatten, valid_ratios, mask_flatten
+            src_flatten,
+            spatial_shapes,
+            src_shapes,
+            level_start_index,
+            lvl_pos_embed_flatten,
+            valid_ratios,
+            mask_flatten,
         )

         # Prepare input for decoder
@@ -515,7 +526,15 @@ class DeformableTransformer(nn.Module):

         # Decoder
         hs, inter_references = self.decoder(
-            tgt, reference_points, memory, spatial_shapes, level_start_index, query_embed, valid_ratios, mask_flatten
+            tgt,
+            reference_points,
+            memory,
+            spatial_shapes,
+            level_start_index,
+            query_embed,
+            valid_ratios,
+            mask_flatten,
+            src_shapes,
         )

         return (hs, reference_points, inter_references)
@@ -587,7 +606,7 @@ class Deformable_DETR(DetectionBaseNet):

         self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
-        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)

         class_embed = nn.Linear(hidden_dim, self.num_classes)
         bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
@@ -641,7 +660,8 @@ class Deformable_DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)

-    def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -650,7 +670,7 @@ class Deformable_DETR(DetectionBaseNet):
         self,
         cls_logits: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
@@ -675,7 +695,7 @@ class Deformable_DETR(DetectionBaseNet):
         self,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)
@@ -709,7 +729,7 @@ class Deformable_DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)

-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)

         loss_ce_list = []
         loss_bbox_list = []
@@ -734,7 +754,7 @@ class Deformable_DETR(DetectionBaseNet):
         return losses

     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=100, dim=1)
@@ -743,14 +763,12 @@ class Deformable_DETR(DetectionBaseNet):
         labels = topk_indexes % class_logits.shape[2]
         labels += 1  # Background offset

-        target_sizes = torch.tensor(image_shapes, device=class_logits.device)
-
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -776,16 +794,7 @@ class Deformable_DETR(DetectionBaseNet):
         return detections

     # pylint: disable=too-many-locals
-    def forward(
-        self,
-        x: torch.Tensor,
-        targets: Optional[list[dict[str, torch.Tensor]]] = None,
-        masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
-    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
-        self._input_check(targets)
-        images = self._to_img_list(x, image_sizes)
-
+    def forward_net(self, x: torch.Tensor, masks: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
         feature_list = list(features.values())
         mask_list = []
@@ -829,6 +838,20 @@ class Deformable_DETR(DetectionBaseNet):
         outputs_class = torch.stack(outputs_classes)
         outputs_coord = torch.stack(outputs_coords)

+        return (outputs_class, outputs_coord)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        targets: Optional[list[dict[str, torch.Tensor]]] = None,
+        masks: Optional[torch.Tensor] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
+        self._input_check(targets)
+        image_sizes_tensor = self._to_img_list(x, image_sizes).image_sizes
+
+        outputs_class, outputs_coord = self.forward_net(x, masks)
+
         losses = {}
         detections: list[dict[str, torch.Tensor]] = []
         if self.training is True:
@@ -838,14 +861,15 @@ class Deformable_DETR(DetectionBaseNet):
             for idx, target in enumerate(targets):
                 boxes = target["boxes"]
                 boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-                boxes = boxes / torch.tensor(images.image_sizes[idx][::-1] * 2, dtype=torch.float32, device=x.device)
+                scale = image_sizes_tensor[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+                boxes = boxes / scale
                 targets[idx]["boxes"] = boxes
                 targets[idx]["labels"] = target["labels"] - 1  # No background

             losses = self.compute_loss(targets, outputs_class, outputs_coord)

         else:
-            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1], images.image_sizes)
+            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1], image_sizes_tensor)

         return (detections, losses)

birder/net/detection/detr.py CHANGED
@@ -49,7 +49,7 @@ class HungarianMatcher(nn.Module):
    @torch.jit.unused  # type: ignore[untyped-decorator]
    def forward(
        self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
-    ) -> list[torch.Tensor]:
+    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        with torch.no_grad():
            B, num_queries = class_logits.shape[:2]

@@ -148,10 +148,9 @@ class TransformerDecoderLayer(nn.Module):
         query_pos: torch.Tensor,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        q = tgt + query_pos
-        k = tgt + query_pos
+        q_k = tgt + query_pos

-        tgt2, _ = self.self_attn(q, k, value=tgt, need_weights=False)
+        tgt2, _ = self.self_attn(q_k, q_k, value=tgt, need_weights=False)
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
         tgt2, _ = self.multihead_attn(
@@ -341,7 +340,7 @@ class DETR(DetectionBaseNet):
         )
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

-        self.matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=1.0, cost_bbox=5.0, cost_giou=2.0)
         empty_weight = torch.ones(self.num_classes)
         empty_weight[0] = 0.1
         self.empty_weight = nn.Buffer(empty_weight)
@@ -365,7 +364,8 @@ class DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)

-    def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -374,7 +374,7 @@ class DETR(DetectionBaseNet):
         self,
         cls_logits: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
         target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
@@ -388,7 +388,7 @@ class DETR(DetectionBaseNet):
         self,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)
@@ -422,7 +422,7 @@ class DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)

-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)

         loss_ce_list = []
         loss_bbox_list = []
@@ -447,20 +447,17 @@ class DETR(DetectionBaseNet):
         return losses

     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         prob = F.softmax(class_logits, -1)
         scores, labels = prob[..., 1:].max(-1)
         labels = labels + 1

-        # TorchScript doesn't support creating tensor from tuples, convert everything to lists
-        target_sizes = torch.tensor([list(s) for s in image_shapes], device=class_logits.device)
-
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -485,16 +482,7 @@ class DETR(DetectionBaseNet):

         return detections

-    def forward(
-        self,
-        x: torch.Tensor,
-        targets: Optional[list[dict[str, torch.Tensor]]] = None,
-        masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
-    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
-        self._input_check(targets)
-        images = self._to_img_list(x, image_sizes)
-
+    def forward_net(self, x: torch.Tensor, masks: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
         x = features[self.backbone.return_stages[-1]]
         if masks is not None:
@@ -505,6 +493,20 @@ class DETR(DetectionBaseNet):
         outputs_class = self.class_embed(hs)
         outputs_coord = self.bbox_embed(hs).sigmoid()

+        return (outputs_class, outputs_coord)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        targets: Optional[list[dict[str, torch.Tensor]]] = None,
+        masks: Optional[torch.Tensor] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
+        self._input_check(targets)
+        image_sizes_tensor = self._to_img_list(x, image_sizes).image_sizes
+
+        outputs_class, outputs_coord = self.forward_net(x, masks)
+
         losses = {}
         detections: list[dict[str, torch.Tensor]] = []
         if self.training is True:
@@ -514,13 +516,14 @@ class DETR(DetectionBaseNet):
             for idx, target in enumerate(targets):
                 boxes = target["boxes"]
                 boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-                boxes = boxes / torch.tensor(images.image_sizes[idx][::-1] * 2, dtype=torch.float32, device=x.device)
+                scale = image_sizes_tensor[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+                boxes = boxes / scale
                 targets[idx]["boxes"] = boxes

             losses = self.compute_loss(targets, outputs_class, outputs_coord)

         else:
-            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1], images.image_sizes)
+            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1], image_sizes_tensor)

         return (detections, losses)

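Note (illustration, not part of the diff): with the forward/forward_net split shown above, the raw per-query outputs and the decoded per-image detections can be obtained separately; a hedged usage sketch, assuming model is a constructed DETR-style detection net from this package:

model.eval()
x = torch.randn(1, 3, 640, 640)
outputs_class, outputs_coord = model.forward_net(x, None)  # raw class logits and cxcywh boxes per query
detections, _ = model(x)                                   # per-image detection dicts via postprocess_detections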