birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +13 -13
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +6 -6
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +4 -4
- birder/layers/attention_pool.py +2 -2
- birder/layers/layer_scale.py +1 -1
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +4 -10
- birder/net/_rope_vit_configs.py +435 -0
- birder/net/_vit_configs.py +466 -0
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +18 -17
- birder/net/cait.py +7 -7
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +27 -27
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +3 -11
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +6 -6
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +11 -11
- birder/net/deit.py +68 -29
- birder/net/deit3.py +69 -204
- birder/net/densenet.py +9 -8
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +31 -30
- birder/net/detection/detr.py +14 -11
- birder/net/detection/efficientdet.py +10 -29
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +5 -4
- birder/net/edgevit.py +13 -14
- birder/net/efficientformer_v1.py +3 -2
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +7 -7
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +3 -3
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +50 -58
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +13 -13
- birder/net/hgnet_v1.py +6 -6
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +5 -15
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +11 -23
- birder/net/metaformer.py +5 -5
- birder/net/mim/crossmae.py +6 -6
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +4 -6
- birder/net/mim/simmim.py +3 -4
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +7 -34
- birder/net/mobilevit_v2.py +6 -54
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +30 -30
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +9 -9
- birder/net/pvt_v2.py +10 -16
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +5 -35
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +4 -1
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +62 -151
- birder/net/rope_flexivit.py +46 -33
- birder/net/rope_vit.py +44 -758
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +69 -21
- birder/net/smt.py +8 -8
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +4 -4
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +13 -3
- birder/net/ssl/franca.py +28 -4
- birder/net/ssl/i_jepa.py +5 -5
- birder/net/ssl/ibot.py +1 -1
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +13 -3
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +5 -8
- birder/net/tiny_vit.py +6 -19
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/van.py +2 -2
- birder/net/vgg.py +1 -10
- birder/net/vit.py +72 -987
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +23 -48
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +16 -13
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +12 -3
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +15 -15
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- birder-0.3.3.dist-info/RECORD +0 -299
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/detection/retinanet.py

@@ -120,7 +120,7 @@ class RetinaNetClassificationHead(nn.Module):
             cls_logits = self.cls_logits(cls_logits)

             # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
-
+            N, _, H, W = cls_logits.shape
             cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
             cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
             cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)

@@ -202,7 +202,7 @@ class RetinaNetRegressionHead(nn.Module):
             bbox_regression = self.bbox_reg(bbox_regression)

             # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
-
+            N, _, H, W = bbox_regression.size()
             bbox_regression = bbox_regression.view(N, -1, 4, H, W)
             bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
             bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)

@@ -395,7 +395,7 @@ class RetinaNet(DetectionBaseNet):

                 # Keep only topk scoring predictions
                 num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
-
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]

                 anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")

@@ -417,7 +417,7 @@ class RetinaNet(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-
+                soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
                 image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
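A standalone sketch of the per-level top-k filtering touched by the third hunk above, with illustrative sizes (none taken from the diff):

    import torch

    topk_candidates = 1000
    num_classes = 80
    scores_per_level = torch.rand(9000)  # flattened (anchor, class) scores for one level
    topk_idxs = torch.arange(scores_per_level.numel())

    # Keep at most topk_candidates scores on this level, then map the surviving
    # flattened indices back to anchor indices
    num_topk = min(topk_candidates, int(topk_idxs.size(0)))
    scores_per_level, idxs = scores_per_level.topk(num_topk)
    topk_idxs = topk_idxs[idxs]
    anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")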
birder/net/detection/rt_detr_v1.py

@@ -1,5 +1,5 @@
 """
-RT-DETR (Real-Time DEtection TRansformer), adapted from
+RT-DETR v1 (Real-Time DEtection TRansformer), adapted from
 https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetr_pytorch

 Paper "DETRs Beat YOLOs on Real-time Object Detection", https://arxiv.org/abs/2304.08069

@@ -114,7 +114,7 @@ def get_contrastive_denoising_training_group(  # pylint: disable=too-many-locals
     # Embed class labels
     input_query_class = class_embed(input_query_class)

-    # Create attention mask
+    # Create attention mask (True = masked)
     target_size = total_denoising_queries + num_queries
     attn_mask = torch.zeros([target_size, target_size], dtype=torch.bool, device=device)
     attn_mask[total_denoising_queries:, :total_denoising_queries] = True
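The updated comment pins down the mask convention. A minimal standalone sketch of the block structure, with illustrative sizes in place of the real total_denoising_queries and num_queries:

    import torch

    total_denoising_queries = 4
    num_queries = 6
    target_size = total_denoising_queries + num_queries

    # True = masked: regular queries (bottom rows) may not attend to the
    # denoising queries (left columns), so ground-truth-derived content
    # cannot leak into ordinary matching
    attn_mask = torch.zeros([target_size, target_size], dtype=torch.bool)
    attn_mask[total_denoising_queries:, :total_denoising_queries] = True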
@@ -212,10 +212,69 @@ class CSPRepLayer(nn.Module):
         return self.conv3(x1 + x2)


+class MultiheadAttention(nn.Module):
+    def __init__(self, d_model: int, num_heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
+        super().__init__()
+        assert d_model % num_heads == 0, "d_model should be divisible by num_heads"
+
+        self.num_heads = num_heads
+        self.head_dim = d_model // num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(d_model, d_model)
+        self.k_proj = nn.Linear(d_model, d_model)
+        self.v_proj = nn.Linear(d_model, d_model)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(d_model, d_model)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        nn.init.xavier_uniform_(self.q_proj.weight)
+        nn.init.xavier_uniform_(self.k_proj.weight)
+        nn.init.xavier_uniform_(self.v_proj.weight)
+        nn.init.xavier_uniform_(self.proj.weight)
+        if self.q_proj.bias is not None:
+            nn.init.zeros_(self.q_proj.bias)
+            nn.init.zeros_(self.k_proj.bias)
+            nn.init.zeros_(self.v_proj.bias)
+            nn.init.zeros_(self.proj.bias)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        B, l_q, C = query.shape
+        q = self.q_proj(query).reshape(B, l_q, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.k_proj(key).reshape(B, key.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.v_proj(value).reshape(B, value.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+
+        if key_padding_mask is not None:
+            # key_padding_mask is expected to be boolean (True = masked)
+            # SDPA expects True = attend, so we invert
+            attn_mask = ~key_padding_mask[:, None, None, :]
+        else:
+            attn_mask = None
+
+        attn = F.scaled_dot_product_attention(  # pylint: disable=not-callable
+            q, k, v, attn_mask=attn_mask, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
+        )
+
+        attn = attn.transpose(1, 2).reshape(B, l_q, C)
+        x = self.proj(attn)
+        x = self.proj_drop(x)
+
+        return x
+
+
 class TransformerEncoderLayer(nn.Module):
     def __init__(self, d_model: int, num_heads: int, dim_feedforward: int, dropout: float) -> None:
         super().__init__()
-        self.self_attn =
+        self.self_attn = MultiheadAttention(d_model, num_heads, attn_drop=dropout)

         self.linear1 = nn.Linear(d_model, dim_feedforward)
         self.dropout = nn.Dropout(dropout)
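Given the MultiheadAttention class added above (and its module's imports: torch, torch.nn as nn, torch.nn.functional as F, typing.Optional), a hypothetical usage sketch with invented tensor sizes. Note that key_padding_mask keeps the nn.MultiheadAttention convention (True = padded/masked), and the inversion to SDPA's True = attend semantics happens inside forward:

    import torch

    attn = MultiheadAttention(d_model=256, num_heads=8)
    src = torch.randn(2, 100, 256)                   # (B, L, C)
    padding = torch.zeros(2, 100, dtype=torch.bool)  # True marks padded keys
    padding[:, 90:] = True                           # last 10 tokens are padding
    out = attn(src, src, src, key_padding_mask=padding)  # -> (2, 100, 256)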
@@ -231,10 +290,8 @@ class TransformerEncoderLayer(nn.Module):
     def forward(
         self, src: torch.Tensor, pos: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
-
-
-
-        (src2, _) = self.self_attn(q, k, value=src, key_padding_mask=key_padding_mask, need_weights=False)
+        qk = src + pos
+        src2 = self.self_attn(qk, qk, value=src, key_padding_mask=key_padding_mask)
         src = src + self.dropout1(src2)
         src = self.norm1(src)

@@ -268,7 +325,7 @@ class AIFI(nn.Module):
         self._pos_cache.clear()

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-
+        B, C, H, W = x.size()
         x = x.flatten(2).permute(0, 2, 1)

         use_cache = self.use_cache is True and torch.jit.is_tracing() is False and torch.jit.is_scripting() is False
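The new first line of AIFI.forward makes the token reshaping explicit: a CNN feature map becomes a sequence for the transformer encoder. A minimal sketch with assumed sizes:

    import torch

    x = torch.randn(2, 256, 20, 20)    # (B, C, H, W) feature map
    B, C, H, W = x.size()
    x = x.flatten(2).permute(0, 2, 1)  # (B, H*W, C) = (2, 400, 256) tokens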
@@ -522,7 +579,7 @@ class RT_DETRDecoder(nn.Module):
         spatial_shapes: list[list[int]],
         memory_padding_mask: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-
+        anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device, dtype=memory.dtype)
         if memory_padding_mask is not None:
             valid_mask = valid_mask & ~memory_padding_mask.unsqueeze(-1)

@@ -535,7 +592,7 @@ class RT_DETRDecoder(nn.Module):
         enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors

         # Select top-k queries based on classification confidence
-
+        _, topk_ind = torch.topk(enc_outputs_class.max(dim=-1).values, self.num_queries, dim=1)

         # Gather reference points
         reference_points_unact = enc_outputs_coord_unact.gather(
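A self-contained sketch of this confidence-based query selection. The batch and query counts are assumed, and since the gather(...) call is truncated in this listing, the index expansion shown is an assumption rather than the file's actual continuation:

    import torch

    B, L, num_classes, num_queries = 2, 500, 80, 300
    enc_outputs_class = torch.randn(B, L, num_classes)
    enc_outputs_coord_unact = torch.randn(B, L, 4)

    # Rank proposals by their best per-class confidence, keep the top num_queries
    _, topk_ind = torch.topk(enc_outputs_class.max(dim=-1).values, num_queries, dim=1)

    # Gather the unactivated box coordinates of the selected proposals
    reference_points_unact = enc_outputs_coord_unact.gather(
        1, topk_ind.unsqueeze(-1).expand(-1, -1, 4)
    )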
@@ -577,7 +634,7 @@ class RT_DETRDecoder(nn.Module):
         memory_padding_mask = torch.concat(mask_flatten, dim=1) if mask_flatten else None

         # Get decoder input (query selection)
-
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
             memory, spatial_shapes, memory_padding_mask
         )

@@ -858,7 +915,7 @@ class RT_DETR_v1(DetectionBaseNet):
            loss_ce = self._class_loss(
                dn_out_logits[layer_idx], dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes
            )
-
+            loss_bbox, loss_giou = self._box_loss(dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes)

            loss_ce_list.append(loss_ce)
            loss_bbox_list.append(loss_bbox)
@@ -899,7 +956,7 @@ class RT_DETR_v1(DetectionBaseNet):
         for layer_idx in range(out_logits.shape[0]):
             indices = self.matcher(out_logits[layer_idx], out_bboxes[layer_idx], targets)
             loss_ce = self._class_loss(out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes)
-
+            loss_bbox, loss_giou = self._box_loss(out_bboxes[layer_idx], targets, indices, num_boxes)
             loss_ce_list.append(loss_ce)
             loss_bbox_list.append(loss_bbox)
             loss_giou_list.append(loss_giou)
@@ -907,7 +964,7 @@ class RT_DETR_v1(DetectionBaseNet):
         # Encoder auxiliary loss
         enc_indices = self.matcher(enc_topk_logits, enc_topk_bboxes, targets)
         loss_ce_enc = self._class_loss(enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes)
-
+        loss_bbox_enc, loss_giou_enc = self._box_loss(enc_topk_bboxes, targets, enc_indices, num_boxes)
         loss_ce_list.append(loss_ce_enc)
         loss_bbox_list.append(loss_bbox_enc)
         loss_giou_list.append(loss_giou_enc)
@@ -918,7 +975,7 @@ class RT_DETR_v1(DetectionBaseNet):

         # Add denoising loss if available
         if dn_out_bboxes is not None and dn_out_logits is not None and dn_meta is not None:
-
+            loss_ce_dn, loss_bbox_dn, loss_giou_dn = self._compute_denoising_loss(
                 dn_out_bboxes, dn_out_logits, targets, dn_meta, num_boxes
             )
             loss_ce = loss_ce + loss_ce_dn
@@ -952,9 +1009,9 @@ class RT_DETR_v1(DetectionBaseNet):
             targets[idx]["boxes"] = boxes
             targets[idx]["labels"] = target["labels"] - 1  # No background

-
+        denoising_class, denoising_bbox_unact, attn_mask, dn_meta = self._prepare_cdn_queries(targets)

-
+        out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits = self.decoder(
             encoder_features,
             spatial_shapes,
             level_start_index,
@@ -965,7 +1022,7 @@ class RT_DETR_v1(DetectionBaseNet):
         )

         if dn_meta is not None:
-
+            dn_num_split, _num_queries = dn_meta["dn_num_split"]
             dn_out_bboxes = out_bboxes[:, :, :dn_num_split]
             dn_out_logits = out_logits[:, :, :dn_num_split]
             out_bboxes = out_bboxes[:, :, dn_num_split:]
@@ -984,9 +1041,7 @@ class RT_DETR_v1(DetectionBaseNet):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
-
-            prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1
-        )
+        topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1)
         scores = topk_values
         topk_boxes = topk_indexes // class_logits.shape[2]
         labels = topk_indexes % class_logits.shape[2]
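The collapsed torch.topk call ranks scores over the flattened (queries × classes) axis, then recovers the query index and class label arithmetically. A sketch with assumed shapes:

    import torch

    B, num_queries, num_classes = 2, 300, 80
    prob = torch.rand(B, num_queries, num_classes)  # sigmoid class scores

    topk_values, topk_indexes = torch.topk(prob.view(B, -1), k=num_queries, dim=1)
    scores = topk_values
    topk_boxes = topk_indexes // num_classes  # which query each score came from
    labels = topk_indexes % num_classes       # which class within that query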
@@ -999,7 +1054,7 @@ class RT_DETR_v1(DetectionBaseNet):
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-
+        img_h, img_w = target_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

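Boxes leave the decoder in relative [0, 1] coordinates; a sketch of the per-image scaling above, assuming target_sizes holds (height, width) rows:

    import torch

    target_sizes = torch.tensor([[480, 640], [512, 512]])  # (B, 2) as (h, w)
    boxes = torch.rand(2, 300, 4)                          # relative coordinates

    img_h, img_w = target_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)  # (B, 4)
    boxes = boxes * scale_fct[:, None, :]                  # absolute pixel coordinates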
@@ -1056,7 +1111,7 @@ class RT_DETR_v1(DetectionBaseNet):
                 mask_size = feat.shape[-2:]
                 m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
             else:
-
+                B, _, H, W = feat.size()
                 m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
             mask_list.append(m)

@@ -1080,7 +1135,7 @@ class RT_DETR_v1(DetectionBaseNet):
             losses = self.compute_loss(encoder_features, spatial_shapes, level_start_index, targets, images, mask_list)
         else:
             # Inference path - no CDN
-
+            out_bboxes, out_logits, _, _ = self.decoder(
                 encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list
             )
             detections = self.postprocess_detections(out_logits[-1], out_bboxes[-1], images.image_sizes)
@@ -1100,6 +1155,7 @@ class RT_DETR_v1(DetectionBaseNet):


 registry.register_model_config(
-    "
+    "rt_detr_v1_t", RT_DETR_v1, config={"num_decoder_layers": 3, "expansion": 0.5, "depth_multiplier": 0.33}
 )
+registry.register_model_config("rt_detr_v1_s", RT_DETR_v1, config={"num_decoder_layers": 3, "expansion": 0.5})
 registry.register_model_config("rt_detr_v1", RT_DETR_v1, config={"num_decoder_layers": 6})