birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
@@ -147,7 +147,7 @@ class MultiScaleDeformableAttention(nn.Module):
             param.requires_grad_(False)

     def reset_parameters(self) -> None:
-        nn.init.
+        nn.init.zeros_(self.sampling_offsets.weight)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]
@@ -158,25 +158,27 @@ class MultiScaleDeformableAttention(nn.Module):
         with torch.no_grad():
             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))

-        nn.init.
-        nn.init.
+        nn.init.zeros_(self.attention_weights.weight)
+        nn.init.zeros_(self.attention_weights.bias)
         nn.init.xavier_uniform_(self.value_proj.weight)
-        nn.init.
+        nn.init.zeros_(self.value_proj.bias)
         nn.init.xavier_uniform_(self.output_proj.weight)
-        nn.init.
+        nn.init.zeros_(self.output_proj.bias)

+    # pylint: disable=too-many-locals
     def forward(
         self,
         query: torch.Tensor,
         reference_points: torch.Tensor,
         input_flatten: torch.Tensor,
         input_spatial_shapes: torch.Tensor,
+        src_shapes: list[list[int]],
         input_level_start_index: torch.Tensor,
         input_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
+        num_queries = query.size(1)
         N, sequence_length, _ = input_flatten.size()
-        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length
+        # assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length

         value = self.value_proj(input_flatten)
         if input_padding_mask is not None:
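Note: the truncated removed lines above were `nn.init.` calls that are now written as explicit `nn.init.zeros_` initialization next to the existing Xavier initialization. A minimal sketch of the resulting pattern, with illustrative layer sizes rather than the module's actual dimensions:

```python
import torch
from torch import nn

# Illustrative projections (sizes are made up, not birder's real configuration)
value_proj = nn.Linear(256, 256)
output_proj = nn.Linear(256, 256)

# Weights get Xavier init, biases are zeroed explicitly
nn.init.xavier_uniform_(value_proj.weight)
nn.init.zeros_(value_proj.bias)
nn.init.xavier_uniform_(output_proj.weight)
nn.init.zeros_(output_proj.bias)
```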
@@ -231,7 +233,7 @@ class MultiScaleDeformableAttention(nn.Module):

         if self.method == "discrete":
             output = self._forward_fallback(
-                value, input_spatial_shapes, sampling_locations, attention_weights, method="discrete"
+                value, input_spatial_shapes, src_shapes, sampling_locations, attention_weights, method="discrete"
             )
         else:
             if self.uniform_points is True:

@@ -245,10 +247,11 @@ class MultiScaleDeformableAttention(nn.Module):
                     sampling_locations,
                     attention_weights,
                     self.im2col_step,
+                    src_shapes,
                 )
             else:
                 output = self._forward_fallback(
-                    value, input_spatial_shapes, sampling_locations, attention_weights, method="default"
+                    value, input_spatial_shapes, src_shapes, sampling_locations, attention_weights, method="default"
                 )

         output = self.output_proj(output)

@@ -258,6 +261,7 @@ class MultiScaleDeformableAttention(nn.Module):
         self,
         value: torch.Tensor,
         spatial_shapes: torch.Tensor,
+        src_shapes: list[list[int]],
         sampling_locations: torch.Tensor,
         attention_weights: torch.Tensor,
         method: str = "default",

@@ -272,8 +276,7 @@ class MultiScaleDeformableAttention(nn.Module):
         sampling_locations_list = sampling_grids.split(self.num_points, dim=-2)

         sampling_value_list = []
-
-        for level, (H, W) in enumerate(spatial_shapes_list):
+        for level, (H, W) in enumerate(src_shapes):
             value_l = value_list[level].reshape(B * n_heads, head_dim, H, W)
             sampling_grid_l = sampling_locations_list[level]
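The fallback path now iterates over `src_shapes`, a list of plain Python `[H, W]` pairs, instead of values read out of the `spatial_shapes` tensor, so per-level reshaping does not depend on tensor data. A rough, self-contained sketch of that split/reshape/sample pattern with made-up sizes (not the network's real layout):

```python
import torch
import torch.nn.functional as F

# Two feature levels; src_shapes holds plain Python ints
src_shapes = [[8, 10], [4, 5]]
value = torch.randn(2, 8, sum(h * w for h, w in src_shapes))  # (batch*heads, head_dim, sum(H*W))
value_list = value.split([h * w for h, w in src_shapes], dim=-1)

sampled = []
for level, (h, w) in enumerate(src_shapes):
    value_l = value_list[level].reshape(2, 8, h, w)
    # Normalized sampling grid in [-1, 1]: (batch*heads, num_queries, num_points, 2)
    grid_l = torch.rand(2, 3, 4, 2) * 2 - 1
    sampled.append(F.grid_sample(value_l, grid_l, mode="bilinear", align_corners=False))

print([s.shape for s in sampled])  # each (2, 8, 3, 4)
```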
@@ -361,21 +364,21 @@ class TransformerDecoderLayer(nn.Module):
         reference_points: torch.Tensor,
         src: torch.Tensor,
         src_spatial_shapes: torch.Tensor,
+        src_shapes: list[list[int]],
         level_start_index: torch.Tensor,
         src_padding_mask: Optional[torch.Tensor],
         self_attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self attention
-
-        k = tgt + query_pos
+        q_k = tgt + query_pos

-        tgt2 = self.self_attn(
+        tgt2 = self.self_attn(q_k, q_k, tgt, attn_mask=self_attn_mask)
         tgt = tgt + self.dropout(tgt2)
         tgt = self.norm1(tgt)

         # Cross attention
         tgt2 = self.cross_attn(
-            tgt + query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask
+            tgt + query_pos, reference_points, src, src_spatial_shapes, src_shapes, level_start_index, src_padding_mask
         )
         tgt = tgt + self.dropout(tgt2)
         tgt = self.norm2(tgt)
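The self-attention call collapses to a single line in which the positional embedding is added to the query and key but not to the value, the usual DETR-style convention. A sketch of the idea using `nn.MultiheadAttention`; the actual attention module and dimensions in birder may differ:

```python
import torch
from torch import nn

# Illustrative sizes only
self_attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)

tgt = torch.randn(2, 300, 256)        # decoder queries
query_pos = torch.randn(2, 300, 256)  # positional embedding for the queries

q_k = tgt + query_pos                 # query and key carry position information
tgt2, _ = self_attn(q_k, q_k, tgt)    # the value stays position-free
tgt = tgt + tgt2                      # residual connection
```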
@@ -526,18 +529,18 @@ class RT_DETRDecoder(nn.Module):

         # Gather reference points
         reference_points_unact = enc_outputs_coord_unact.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
         )

         enc_topk_bboxes = reference_points_unact.sigmoid()

         # Gather encoder logits for loss computation
         enc_topk_logits = enc_outputs_class.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_class.shape[-1])
         )

         # Extract region features
-        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).
+        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1]))
         target = target.detach()

         return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)
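The truncated `gather` lines are completed with an explicit `expand` of the top-k index to the channel dimension, since `gather` requires an index tensor with the same number of dimensions as the source. A small standalone illustration with arbitrary sizes:

```python
import torch

output_memory = torch.randn(2, 100, 256)              # (batch, queries, channels)
scores = torch.randn(2, 100)
topk_ind = scores.topk(10, dim=1).indices             # (2, 10)

# Expand the index across channels, then gather the selected query embeddings
index = topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1])  # (2, 10, 256)
target = output_memory.gather(dim=1, index=index)     # (2, 10, 256)
```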
@@ -551,6 +554,7 @@ class RT_DETRDecoder(nn.Module):
         denoising_bbox_unact: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         padding_mask: Optional[list[torch.Tensor]] = None,
+        return_intermediates: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         memory = []
         mask_flatten = []

@@ -578,18 +582,19 @@
         level_start_index_tensor = torch.tensor(level_start_index, dtype=torch.long, device=memory.device)

         # Decoder forward
-
-
+        bboxes_list: list[torch.Tensor] = []
+        logits_list: list[torch.Tensor] = []
         reference_points = init_ref_points_unact.sigmoid()
         for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
             query_pos = self.query_pos_head(reference_points)
-            reference_points_input = reference_points.unsqueeze(2).
+            reference_points_input = reference_points.unsqueeze(2).expand(-1, -1, len(spatial_shapes), -1)
             target = decoder_layer(
                 target,
                 query_pos,
                 reference_points_input,
                 memory,
                 spatial_shapes_tensor,
+                spatial_shapes,
                 level_start_index_tensor,
                 memory_padding_mask,
                 attn_mask,

@@ -602,14 +607,19 @@
             # Classification
             class_logits = class_head(target)

-
-
+            if return_intermediates is True:
+                bboxes_list.append(new_reference_points)
+                logits_list.append(class_logits)

             # Update reference points for next layer
             reference_points = new_reference_points.detach()

-
-
+        if return_intermediates is True:
+            out_bboxes = torch.stack(bboxes_list)
+            out_logits = torch.stack(logits_list)
+        else:
+            out_bboxes = new_reference_points
+            out_logits = class_logits

         return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits)

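With the new `return_intermediates` flag the decoder either stacks every layer's predictions (needed for the auxiliary losses during training) or returns only the last layer's tensors, which is enough for inference. Illustrative shapes only:

```python
import torch

layers, batch, queries = 6, 2, 300
bboxes_list = [torch.rand(batch, queries, 4) for _ in range(layers)]

out_bboxes = torch.stack(bboxes_list)  # (layers, batch, queries, 4) -> per-layer aux losses
final_only = bboxes_list[-1]           # (batch, queries, 4) -> inference path
```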
@@ -675,7 +685,7 @@ class RT_DETR_v2(DetectionBaseNet):
         self.decoder = RT_DETRDecoder(
             hidden_dim=hidden_dim,
             num_classes=self.num_classes,
-            num_queries=num_queries,
+            num_queries=self.num_queries,
             num_decoder_layers=num_decoder_layers,
             num_levels=self.num_levels,
             num_heads=num_heads,

@@ -744,20 +754,32 @@ class RT_DETR_v2(DetectionBaseNet):
         for param in self.denoising_class_embed.parameters():
             param.requires_grad_(True)

-
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)

-    def
+    def _compute_layer_losses(
         self,
         cls_logits: torch.Tensor,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: float,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)
+
+        src_boxes = box_output[idx]
+        target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        src_boxes_xyxy = box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+        target_boxes_xyxy = box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+
+        # IoU for varifocal loss (class loss)
+        ious = torch.diag(box_ops.box_iou(src_boxes_xyxy, target_boxes_xyxy)).detach()
+
+        # Classification loss
         target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
         target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
         target_classes[idx] = target_classes_o
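The new `_compute_layer_losses` folds the former `_box_loss` into a single method: matched predictions and ground truths are converted from cxcywh to xyxy once, and the matched-pair IoU feeds the varifocal classification targets. A self-contained sketch of that IoU term with random boxes (assuming `box_ops` is torchvision's box utilities, as the surrounding code suggests; not birder's matcher output):

```python
import torch
from torchvision.ops import boxes as box_ops

# Matched predictions and targets in normalized cxcywh (purely illustrative)
src_boxes = torch.rand(5, 4) * 0.5 + 0.25
target_boxes = torch.rand(5, 4) * 0.5 + 0.25

src_xyxy = box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy")
tgt_xyxy = box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy")

# Only the diagonal of the pairwise IoU matrix (i.e. matched pairs) is needed
ious = torch.diag(box_ops.box_iou(src_xyxy, tgt_xyxy)).detach()  # shape (5,)
```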
@@ -771,15 +793,6 @@ class RT_DETR_v2(DetectionBaseNet):
         target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
         target_classes_onehot = target_classes_onehot[:, :, :-1]

-        src_boxes = box_output[idx]
-        target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-        ious = torch.diag(
-            box_ops.box_iou(
-                box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
-                box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
-            )
-        ).detach()
-
         target_score_o = torch.zeros(cls_logits.shape[:2], dtype=cls_logits.dtype, device=cls_logits.device)
         target_score_o[idx] = ious.to(cls_logits.dtype)
         target_score = target_score_o.unsqueeze(-1) * target_classes_onehot

@@ -787,31 +800,13 @@ class RT_DETR_v2(DetectionBaseNet):
         loss = varifocal_loss(cls_logits, target_score, target_classes_onehot, alpha=0.75, gamma=2.0)
         loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.shape[1]

-
-
-    def _box_loss(
-        self,
-        box_output: torch.Tensor,
-        targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
-        num_boxes: float,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        idx = self._get_src_permutation_idx(indices)
-        src_boxes = box_output[idx]
-        target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
-        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
-        loss_bbox = loss_bbox.sum() / num_boxes
+        # Box L1 loss
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none").sum() / num_boxes

-
-
-                box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
-                box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
-            )
-        )
-        loss_giou = loss_giou.sum() / num_boxes
+        # GIoU loss
+        loss_giou = (1 - torch.diag(box_ops.generalized_box_iou(src_boxes_xyxy, target_boxes_xyxy))).sum() / num_boxes

-        return (loss_bbox, loss_giou)
+        return (loss_ce, loss_bbox, loss_giou)

     def _compute_denoising_loss(
         self,
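The box losses now reuse the already-converted boxes: an L1 term and a GIoU term, both normalized by the number of matched boxes. A toy example with hand-picked boxes (matching and loss weights omitted; `box_ops` again assumed to be torchvision's box utilities):

```python
import torch
import torch.nn.functional as F
from torchvision.ops import boxes as box_ops

num_boxes = 2.0
src_xyxy = torch.tensor([[0.1, 0.1, 0.4, 0.4], [0.5, 0.5, 0.9, 0.9]])
tgt_xyxy = torch.tensor([[0.1, 0.2, 0.4, 0.5], [0.4, 0.5, 0.8, 0.9]])

# L1 regression term over matched pairs
loss_bbox = F.l1_loss(src_xyxy, tgt_xyxy, reduction="none").sum() / num_boxes

# GIoU term: 1 - diag of the pairwise generalized IoU matrix
loss_giou = (1 - torch.diag(box_ops.generalized_box_iou(src_xyxy, tgt_xyxy))).sum() / num_boxes
```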
@@ -846,11 +841,9 @@ class RT_DETR_v2(DetectionBaseNet):
                 )
             )

-            loss_ce = self.
+            loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
                 dn_out_logits[layer_idx], dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes
             )
-            loss_bbox, loss_giou = self._box_loss(dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes)
-
             loss_ce_list.append(loss_ce)
             loss_bbox_list.append(loss_bbox)
             loss_giou_list.append(loss_giou)

@@ -861,9 +854,7 @@ class RT_DETR_v2(DetectionBaseNet):

         return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)

-
-    @torch.compiler.disable()  # type: ignore[untyped-decorator]
-    def _compute_loss_from_outputs(  # pylint: disable=too-many-locals
+    def _compute_loss_from_outputs(
         self,
         targets: list[dict[str, torch.Tensor]],
         out_bboxes: torch.Tensor,

@@ -880,7 +871,7 @@ class RT_DETR_v2(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)

-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)

         loss_ce_list = []
         loss_bbox_list = []

@@ -889,19 +880,21 @@ class RT_DETR_v2(DetectionBaseNet):
         # Decoder losses (all layers)
         for layer_idx in range(out_logits.shape[0]):
             indices = self.matcher(out_logits[layer_idx], out_bboxes[layer_idx], targets)
-            loss_ce = self.
-
+            loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
+                out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes
+            )
             loss_ce_list.append(loss_ce)
             loss_bbox_list.append(loss_bbox)
             loss_giou_list.append(loss_giou)

         # Encoder auxiliary loss
         enc_indices = self.matcher(enc_topk_logits, enc_topk_bboxes, targets)
-
-
-
-
-
+        loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
+            enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes
+        )
+        loss_ce_list.append(loss_ce)
+        loss_bbox_list.append(loss_bbox)
+        loss_giou_list.append(loss_giou)

         loss_ce = torch.stack(loss_ce_list).sum()  # VFL weight is 1
         loss_bbox = torch.stack(loss_bbox_list).sum() * 5

@@ -935,11 +928,11 @@ class RT_DETR_v2(DetectionBaseNet):
         images: Any,
         masks: Optional[list[torch.Tensor]] = None,
     ) -> dict[str, torch.Tensor]:
-        device = encoder_features[0].device
         for idx, target in enumerate(targets):
             boxes = target["boxes"]
             boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-
+            scale = images.image_sizes[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+            boxes = boxes / scale
             targets[idx]["boxes"] = boxes
             targets[idx]["labels"] = target["labels"] - 1  # No background

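Target boxes are now normalized with the per-image size taken from `images.image_sizes`, which stores `(H, W)`: flipping gives `(W, H)` and repeating gives the `(W, H, W, H)` divisor for cxcywh boxes. For example:

```python
import torch

image_size = torch.tensor([480, 640])          # (H, W)
scale = image_size.flip(0).repeat(2).float()   # tensor([640., 480., 640., 480.])

boxes_cxcywh = torch.tensor([[320.0, 240.0, 100.0, 50.0]])
normalized = boxes_cxcywh / scale              # every coordinate now lies in [0, 1]
```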
@@ -972,7 +965,7 @@ class RT_DETR_v2(DetectionBaseNet):
         return losses

     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor,
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1)

@@ -981,14 +974,12 @@ class RT_DETR_v2(DetectionBaseNet):
         labels = topk_indexes % class_logits.shape[2]
         labels += 1  # Background offset

-        target_sizes = torch.tensor(image_shapes, device=class_logits.device)
-
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
-        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).
+        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).expand(-1, -1, 4))

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w =
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

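`postprocess_detections` now receives the image sizes as a tensor and scales the normalized predictions back to pixels. A standalone sketch with made-up predictions, assuming rows of `image_sizes` are `(height, width)`:

```python
import torch
from torchvision.ops import boxes as box_ops

box_regression = torch.rand(2, 300, 4)                  # normalized cxcywh predictions
image_sizes = torch.tensor([[480, 640], [512, 512]])    # (H, W) per image

boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
img_h, img_w = image_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)  # (2, 4)
boxes = boxes * scale_fct[:, None, :]                          # broadcast over queries
```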
@@ -1024,32 +1015,34 @@ class RT_DETR_v2(DetectionBaseNet):

         return (None, None, None, None)

+    def forward_net(
+        self, x: torch.Tensor, masks: Optional[torch.Tensor]
+    ) -> tuple[list[torch.Tensor], Optional[list[torch.Tensor]]]:
+        features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
+        feature_list = list(features.values())
+
+        mask_list: Optional[list[torch.Tensor]] = None
+        if masks is not None:
+            mask_list = []
+            for feat in feature_list:
+                m = F.interpolate(masks[None].float(), size=feat.shape[-2:], mode="nearest").to(torch.bool)[0]
+                mask_list.append(m)
+
+        encoder_features = self.encoder(feature_list, masks=mask_list)
+
+        return (encoder_features, mask_list)
+
     def forward(
         self,
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)

-
-        features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
-        feature_list = list(features.values())
-
-        # Hybrid encoder
-        mask_list: list[torch.Tensor] = []
-        for feat in feature_list:
-            if masks is not None:
-                mask_size = feat.shape[-2:]
-                m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
-            else:
-                B, _, H, W = feat.size()
-                m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
-            mask_list.append(m)
-
-        encoder_features = self.encoder(feature_list, masks=mask_list)
+        encoder_features, mask_list = self.forward_net(x, masks)

         # Prepare spatial shapes and level start index
         spatial_shapes: list[list[int]] = []
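The backbone/encoder call moves into `forward_net`, and masks are only built when provided: each level's mask comes from nearest-neighbor resizing of the boolean padding mask to that feature map's spatial size. Minimal illustration with arbitrary shapes:

```python
import torch
import torch.nn.functional as F

masks = torch.zeros(2, 512, 512, dtype=torch.bool)
masks[:, :, 400:] = True                  # padded region on the right
feat = torch.randn(2, 256, 32, 32)        # one feature level

m = F.interpolate(masks[None].float(), size=feat.shape[-2:], mode="nearest").to(torch.bool)[0]
print(m.shape)  # torch.Size([2, 32, 32])
```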
@@ -1070,9 +1063,9 @@ class RT_DETR_v2(DetectionBaseNet):
         else:
             # Inference path - no CDN
             out_bboxes, out_logits, _, _ = self.decoder(
-                encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list
+                encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list, return_intermediates=False
             )
-            detections = self.postprocess_detections(out_logits
+            detections = self.postprocess_detections(out_logits, out_bboxes, images.image_sizes)

         return (detections, losses)

birder/net/detection/ssd.py CHANGED

@@ -30,6 +30,7 @@ from birder.net.detection.base import BoxCoder
 from birder.net.detection.base import DetectionBaseNet
 from birder.net.detection.base import ImageList
 from birder.net.detection.base import Matcher
+from birder.net.detection.base import clip_boxes_to_image


 class SSDMatcher(Matcher):

@@ -303,6 +304,12 @@ class SSD(DetectionBaseNet):
         topk_candidates = 400
         positive_fraction = 0.25

+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
+        self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction
+
         self.backbone.return_channels = self.backbone.return_channels[-2:]
         self.backbone.return_stages = self.backbone.return_stages[-2:]
         self.extra_blocks = nn.ModuleList(

@@ -325,11 +332,8 @@ class SSD(DetectionBaseNet):
         self.head = SSDHead(self.backbone.return_channels + [512, 256, 256, 256], num_anchors, self.num_classes)
         self.proposal_matcher = SSDMatcher(iou_thresh)

-        self.
-
-        self.detections_per_img = detections_per_img
-        self.topk_candidates = topk_candidates
-        self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction
+        if self.export_mode is False:
+            self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]

     def reset_classifier(self, num_classes: int) -> None:
         self.num_classes = num_classes + 1
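Instead of keeping the attribute assignments at this point in `__init__`, SSD now wraps its own `forward` with `torch.compiler.disable(recursive=False)` outside export mode, so only that frame is excluded from compilation while code it calls can still be compiled. A sketch of the pattern on a toy module (assuming PyTorch 2.1+, where `torch.compiler.disable` accepts `recursive`):

```python
import torch
from torch import nn

class Detector(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 8)
        # Skip only the forward frame itself; recursive=False leaves callees compilable
        self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = Detector()
out = model(torch.randn(2, 8))
```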
@@ -348,6 +352,8 @@ class SSD(DetectionBaseNet):
             param.requires_grad_(True)

     # pylint: disable=too-many-locals
+    @torch.jit.unused  # type: ignore[untyped-decorator]
+    @torch.compiler.disable()  # type: ignore[untyped-decorator]
     def compute_loss(
         self,
         targets: list[dict[str, torch.Tensor]],

@@ -423,7 +429,7 @@ class SSD(DetectionBaseNet):
         self,
         head_outputs: dict[str, torch.Tensor],
         image_anchors: list[torch.Tensor],
-
+        image_sizes: torch.Tensor,
     ) -> list[dict[str, torch.Tensor]]:
         bbox_regression = head_outputs["bbox_regression"]
         pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1)

@@ -431,11 +437,10 @@ class SSD(DetectionBaseNet):
         num_classes = pred_scores.size(-1)
         device = pred_scores.device
         detections: list[dict[str, torch.Tensor]] = []
-        for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors,
+        for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors, image_sizes):
             boxes = self.box_coder.decode_single(boxes, anchors)
-            boxes =
+            boxes = clip_boxes_to_image(boxes, image_shape)

-            list_empty = True
             image_boxes_list = []
             image_scores_list = []
             image_labels_list = []

@@ -447,51 +452,62 @@ class SSD(DetectionBaseNet):
                 box = boxes[keep_idxs]

                 # Keep only topk scoring predictions
-                num_topk = min(self.topk_candidates,
+                num_topk = min(self.topk_candidates, score.size(0))
                 score, idxs = score.topk(num_topk)
                 box = box[idxs]
-                if len(box) == 0 and list_empty is False:
-                    continue

                 image_boxes_list.append(box)
                 image_scores_list.append(score)
                 image_labels_list.append(torch.full_like(score, fill_value=label, dtype=torch.int64, device=device))
-                list_empty = False

             image_boxes = torch.concat(image_boxes_list, dim=0)
             image_scores = torch.concat(image_scores_list, dim=0)
             image_labels = torch.concat(image_labels_list, dim=0)

-
-
-
-
-
-
-
-
-
-
-
+            if self.export_mode is False:
+                # Non-maximum suppression
+                keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+                keep = keep[: self.detections_per_img]
+
+                detections.append(
+                    {
+                        "boxes": image_boxes[keep],
+                        "scores": image_scores[keep],
+                        "labels": image_labels[keep],
+                    }
+                )
+            else:
+                detections.append(
+                    {
+                        "boxes": image_boxes,
+                        "scores": image_scores,
+                        "labels": image_labels,
+                    }
+                )

         return detections

+    def forward_net(self, x: torch.Tensor) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
+        features = self.backbone.detection_features(x)
+        feature_list = list(features.values())
+        for extra_block in self.extra_blocks:
+            feature_list.append(extra_block(feature_list[-1]))
+
+        head_outputs = self.head(feature_list)
+
+        return (feature_list, head_outputs)
+
     def forward(
         self,
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)

-
-        feature_list = list(features.values())
-        for extra_block in self.extra_blocks:
-            feature_list.append(extra_block(feature_list[-1]))
-
-        head_outputs = self.head(feature_list)
+        feature_list, head_outputs = self.forward_net(x)
         anchors = self.anchor_generator(images, feature_list)

         losses = {}
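The per-image post-processing now always builds the full candidate lists and only applies class-aware NMS plus the detections-per-image cap when not in export mode. A standalone sketch of that final step with illustrative boxes and thresholds (again assuming torchvision's box utilities):

```python
import torch
from torchvision.ops import boxes as box_ops

image_boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [20.0, 20.0, 30.0, 30.0]])
image_scores = torch.tensor([0.9, 0.8, 0.7])
image_labels = torch.tensor([1, 1, 2])

# Class-aware NMS, then cap the number of detections per image
keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, iou_threshold=0.45)
keep = keep[:100]
result = {"boxes": image_boxes[keep], "scores": image_scores[keep], "labels": image_labels[keep]}
```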
birder/net/detection/ssdlite.py CHANGED

@@ -50,7 +50,7 @@ class SSDLiteClassificationHead(SSDScoringHead):
             if isinstance(layer, nn.Conv2d):
                 nn.init.xavier_uniform_(layer.weight)
                 if layer.bias is not None:
-                    nn.init.
+                    nn.init.zeros_(layer.bias)

         super().__init__(cls_logits, num_classes)

@@ -79,7 +79,7 @@ class SSDLiteRegressionHead(SSDScoringHead):
             if isinstance(layer, nn.Conv2d):
                 nn.init.xavier_uniform_(layer.weight)
                 if layer.bias is not None:
-                    nn.init.
+                    nn.init.zeros_(layer.bias)

         super().__init__(bbox_reg, 4)
