birder 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/training_cli.py +6 -0
- birder/common/training_utils.py +215 -31
- birder/data/collators/detection.py +1 -0
- birder/data/dataloader/webdataset.py +12 -2
- birder/kernels/load_kernel.py +16 -11
- birder/kernels/soft_nms/soft_nms.cpp +17 -18
- birder/net/cait.py +4 -3
- birder/net/convnext_v1.py +5 -0
- birder/net/crossformer.py +33 -30
- birder/net/crossvit.py +4 -3
- birder/net/deit.py +3 -3
- birder/net/deit3.py +3 -3
- birder/net/detection/deformable_detr.py +2 -5
- birder/net/detection/detr.py +2 -5
- birder/net/detection/efficientdet.py +2 -7
- birder/net/detection/fcos.py +2 -7
- birder/net/detection/retinanet.py +2 -7
- birder/net/detection/rt_detr_v1.py +1 -0
- birder/net/efficientformer_v1.py +15 -9
- birder/net/efficientformer_v2.py +39 -29
- birder/net/efficientvit_msft.py +9 -7
- birder/net/fastvit.py +1 -0
- birder/net/flexivit.py +5 -4
- birder/net/hiera.py +12 -9
- birder/net/hornet.py +9 -7
- birder/net/iformer.py +8 -6
- birder/net/levit.py +42 -30
- birder/net/lit_v1_tiny.py +15 -0
- birder/net/maxvit.py +67 -55
- birder/net/mobileone.py +1 -0
- birder/net/mvit_v2.py +13 -12
- birder/net/pit.py +4 -3
- birder/net/pvt_v1.py +4 -1
- birder/net/repghost.py +1 -0
- birder/net/repvgg.py +1 -0
- birder/net/repvit.py +1 -0
- birder/net/rope_deit3.py +5 -3
- birder/net/rope_flexivit.py +7 -4
- birder/net/rope_vit.py +10 -5
- birder/net/simple_vit.py +9 -6
- birder/net/swin_transformer_v1.py +71 -68
- birder/net/swin_transformer_v2.py +38 -31
- birder/net/tiny_vit.py +20 -10
- birder/net/transnext.py +38 -28
- birder/net/vit.py +5 -4
- birder/net/vit_parallel.py +5 -4
- birder/net/vit_sam.py +38 -37
- birder/net/vovnet_v1.py +15 -0
- birder/ops/msda.py +108 -43
- birder/ops/swattention.py +124 -61
- birder/results/detection.py +4 -0
- birder/scripts/benchmark.py +21 -12
- birder/scripts/predict.py +7 -0
- birder/scripts/train.py +39 -13
- birder/scripts/train_barlow_twins.py +35 -12
- birder/scripts/train_byol.py +35 -12
- birder/scripts/train_capi.py +41 -15
- birder/scripts/train_data2vec.py +37 -14
- birder/scripts/train_data2vec2.py +37 -14
- birder/scripts/train_detection.py +36 -11
- birder/scripts/train_dino_v1.py +51 -14
- birder/scripts/train_dino_v2.py +78 -19
- birder/scripts/train_dino_v2_dist.py +76 -17
- birder/scripts/train_franca.py +43 -19
- birder/scripts/train_i_jepa.py +37 -14
- birder/scripts/train_ibot.py +43 -20
- birder/scripts/train_kd.py +39 -13
- birder/scripts/train_mim.py +35 -12
- birder/scripts/train_mmcr.py +35 -12
- birder/scripts/train_rotnet.py +36 -13
- birder/scripts/train_simclr.py +35 -12
- birder/scripts/train_vicreg.py +35 -12
- birder/tools/convert_model.py +18 -15
- birder/tools/det_results.py +114 -2
- birder/tools/quantize_model.py +73 -67
- birder/version.py +1 -1
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/METADATA +2 -1
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/RECORD +82 -82
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/net/detection/fcos.py
CHANGED
@@ -455,13 +455,8 @@ class FCOS(DetectionBaseNet):
 
             # Non-maximum suppression
             if self.soft_nms is not None:
-
-
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
 

birder/net/detection/retinanet.py
CHANGED
@@ -417,13 +417,8 @@ class RetinaNet(DetectionBaseNet):
 
             # Non-maximum suppression
             if self.soft_nms is not None:
-
-
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
 
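
Note on the two detection hunks above: with the updated soft-NMS kernel, boxes, scores, and labels no longer take a round trip through the CPU, and the returned scores are written back in place without any .to(device) copies. A minimal illustrative sketch in plain PyTorch (not birder code) of the device requirement that makes this work:

    import torch

    # Illustration only: indexed in-place assignment needs the index and value tensors
    # on the same device as the target, which the on-device soft-NMS now guarantees.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_scores = torch.rand(10, device=device)
    keep = torch.tensor([1, 3, 5], device=device)           # indices kept by (soft-)NMS
    soft_scores = torch.rand(keep.numel(), device=device)   # decayed scores for the kept boxes
    image_scores[keep] = soft_scores                         # no .cpu()/.to(device) shuffling needed
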
birder/net/efficientformer_v1.py
CHANGED
@@ -357,16 +357,22 @@ class EfficientFormer_v1(BaseNet):
         resolution = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
         for m in self.body.modules():
             if isinstance(m, Attention):
-
-
-
+                with torch.no_grad():
+                    m.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
+                    )
 
-
-
-
-
-
+                    device = m.attention_biases.device
+                    pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(resolution[0], device=device),
+                            torch.arange(resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                    m.attention_bias_idxs = nn.Buffer(rel_pos)
 
 
 registry.register_model_config(
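
The efficientformer_v1 hunk is one instance of a pattern repeated throughout this release: the adjust_size parameter surgery is wrapped in torch.no_grad() and the index buffer is rebuilt on the device the module already occupies. A standalone sketch of that index computation (illustrative helper, not a birder API):

    import torch


    def build_attention_bias_idxs(resolution: tuple[int, int], device: torch.device) -> torch.Tensor:
        # Pairwise relative offsets over an (H, W) grid, folded into a single integer
        # index into the learned attention-bias table.
        pos = torch.stack(
            torch.meshgrid(
                torch.arange(resolution[0], device=device),
                torch.arange(resolution[1], device=device),
                indexing="ij",
            )
        ).flatten(1)
        rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
        return (rel_pos[0] * resolution[1]) + rel_pos[1]


    idxs = build_attention_bias_idxs((7, 7), torch.device("cpu"))
    print(idxs.shape)  # torch.Size([49, 49])
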
birder/net/efficientformer_v2.py
CHANGED
@@ -554,26 +554,30 @@ class EfficientFormer_v2(DetectorBackbone):
                 attn.N = attn.resolution[0] * attn.resolution[1]
                 attn.N2 = attn.resolution2[0] * attn.resolution2[1]
 
-
-
-
-
-
-                k_pos = torch.stack(
-                    torch.meshgrid(
-                        torch.arange(attn.resolution[0]), torch.arange(attn.resolution[1]), indexing="ij"
+                with torch.no_grad():
+                    # Interpolate attention_biases
+                    attn.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(attn.attention_biases, old_base, new_base)
                     )
-
-
-                        torch.
-                        torch.
-
-
-
-
-
-
-
+
+                    device = attn.attention_biases.device
+                    k_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(attn.resolution[0], device=device),
+                            torch.arange(attn.resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    q_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(0, attn.resolution[0], step=2, device=device),
+                            torch.arange(0, attn.resolution[1], step=2, device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * attn.resolution[1]) + rel_pos[1]
+                    attn.attention_bias_idxs = nn.Buffer(rel_pos.to(torch.long), persistent=False)
 
                 old_base = (old_base[0] // 2, old_base[1] // 2)
                 new_base = (new_base[0] // 2, new_base[1] // 2)
@@ -590,16 +594,22 @@ class EfficientFormer_v2(DetectorBackbone):
             m.token_mixer.resolution = c_new_base
             m.token_mixer.N = m.token_mixer.resolution[0] * m.token_mixer.resolution[1]
 
-
-
-
-
-
-
-
-
-
-
+            with torch.no_grad():
+                m.token_mixer.attention_biases = nn.Parameter(
+                    interpolate_attention_bias(m.token_mixer.attention_biases, c_old_base, c_new_base)
+                )
+
+                device = m.token_mixer.attention_biases.device
+                pos = torch.stack(
+                    torch.meshgrid(
+                        torch.arange(c_new_base[0], device=device),
+                        torch.arange(c_new_base[1], device=device),
+                        indexing="ij",
+                    )
+                ).flatten(1)
+                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                rel_pos = (rel_pos[0] * c_new_base[1]) + rel_pos[1]
+                m.token_mixer.attention_bias_idxs = nn.Buffer(rel_pos.to(torch.long), persistent=False)
 
 
 registry.register_model_config(
birder/net/efficientvit_msft.py
CHANGED
@@ -497,14 +497,16 @@ class EfficientViT_MSFT(DetectorBackbone):
 
                         idxs.append(attention_offsets[offset])
 
-
-
-
+                with torch.no_grad():
+                    m.mixer.m.attn.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(
+                            m.mixer.m.attn.attention_biases, old_window_resolution, window_resolution
+                        )
+                    )
+                    device = m.mixer.m.attn.attention_biases.device
+                    m.mixer.m.attn.attention_bias_idxs = nn.Buffer(
+                        torch.tensor(idxs, device=device, dtype=torch.long).view(N, N), persistent=False
                     )
-                )
-                m.mixer.m.attn.attention_bias_idxs = nn.Buffer(
-                    torch.LongTensor(idxs).view(N, N), persistent=False
-                )
 
 
 registry.register_model_config(
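
Besides the torch.no_grad() wrapper, the efficientvit_msft hunk swaps the legacy torch.LongTensor(idxs) constructor, which always allocates on the CPU, for torch.tensor(idxs, device=..., dtype=torch.long). A small illustrative contrast (not birder code):

    import torch

    idxs = [0, 1, 1, 0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cpu_only = torch.LongTensor(idxs)                                 # legacy constructor, always CPU
    on_device = torch.tensor(idxs, device=device, dtype=torch.long)   # built where the module lives

    print(cpu_only.device, on_device.device)
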
birder/net/fastvit.py
CHANGED
@@ -879,6 +879,7 @@ class FastViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         x = self.forward_features(x)
         return self.features(x)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         for module in self.modules():
             if hasattr(module, "reparameterize") is True:
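
The @torch.no_grad() decorator added to reparameterize_model here (and in mobileone and repghost below) has the same effect as wrapping the body in a with torch.no_grad(): block: the in-place branch fusion is kept out of autograd. A minimal sketch of the idea using a made-up module, not the birder classes:

    import torch
    from torch import nn


    class TinyRep(nn.Module):
        # Hypothetical stand-in for a reparameterizable block
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.ones(3))

        @torch.no_grad()
        def reparameterize(self) -> None:
            self.weight.mul_(2.0)  # parameter surgery without recording autograd history


    m = TinyRep()
    m.reparameterize()
    print(m.weight)  # Parameter containing: tensor([2., 2., 2.], requires_grad=True)
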
birder/net/flexivit.py
CHANGED
@@ -519,15 +519,16 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         else:
             num_prefix_tokens = 0
 
-
-
-
+        with torch.no_grad():
+            pos_embedding = adjust_position_embedding(
+                # On rounding error see: https://github.com/facebookresearch/dino/issues/8
                 self.pos_embedding,
                 (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
                 (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
                 num_prefix_tokens,
             )
-
+
+        self.pos_embedding = nn.Parameter(pos_embedding)
 
     def adjust_patch_size(self, patch_size: int) -> None:
         if self.patch_size == patch_size:
birder/net/hiera.py
CHANGED
@@ -612,23 +612,26 @@ class Hiera(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
 
         if self.pos_embed_win is not None:
             global_pos_size = (new_size[0] // 2**4, new_size[1] // 2**4)
-
-
-
-
-
-
+            with torch.no_grad():
+                pos_embed = F.interpolate(
+                    self.pos_embed,
+                    size=global_pos_size,
+                    mode="bicubic",
+                    antialias=True,
+                )
+
             self.pos_embed = nn.Parameter(pos_embed)
 
         else:
-
-            adjust_position_embedding(
+            with torch.no_grad():
+                pos_embed = adjust_position_embedding(
                     self.pos_embed,
                     (old_size[0] // self.patch_stride[0], old_size[1] // self.patch_stride[1]),
                     (new_size[0] // self.patch_stride[0], new_size[1] // self.patch_stride[1]),
                     0,
                 )
-
+
+            self.pos_embed = nn.Parameter(pos_embed)
 
         # Re-init vars
         self.tokens_spatial_shape = [i // s for i, s in zip(new_size, self.patch_stride)]
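
The hiera change performs the positional-embedding resize under torch.no_grad() and re-wraps the result as a fresh nn.Parameter. A minimal sketch of that resize pattern on an assumed (1, C, H, W) embedding (illustrative shapes, not the birder helper):

    import torch
    import torch.nn.functional as F
    from torch import nn

    pos_embed = nn.Parameter(torch.randn(1, 96, 7, 7))  # assumed (1, C, H, W) layout
    with torch.no_grad():
        resized = F.interpolate(pos_embed, size=(14, 14), mode="bicubic", antialias=True)
    pos_embed = nn.Parameter(resized)
    print(pos_embed.shape)  # torch.Size([1, 96, 14, 14])
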
birder/net/hornet.py
CHANGED
@@ -332,13 +332,15 @@ class HorNet(DetectorBackbone):
         for m in module.modules():
             if isinstance(m, HorBlock):
                 if isinstance(m.gn_conv.dwconv, GlobalLocalFilter):
-
-
-                    weight.
-
-
-
-
+                    with torch.no_grad():
+                        weight = m.gn_conv.dwconv.complex_weight
+                        weight = F.interpolate(
+                            weight.permute(3, 0, 1, 2),
+                            size=(gn_conv_h[i], gn_conv_w[i]),
+                            mode="bilinear",
+                            align_corners=True,
+                        ).permute(1, 2, 3, 0)
+
                     m.gn_conv.dwconv.complex_weight = nn.Parameter(weight)
 
 
birder/net/iformer.py
CHANGED
@@ -477,12 +477,14 @@ class iFormer(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         resolution = (new_size[0] // 4, new_size[1] // 4)
         for stage in self.body.modules():
             if isinstance(stage, InceptionTransformerStage):
-
-
-
-                pos_embedding
-
-
+                with torch.no_grad():
+                    orig_dtype = stage.pos_embed.dtype
+                    pos_embedding = stage.pos_embed.float()
+                    pos_embedding = F.interpolate(
+                        pos_embedding.permute(0, 3, 1, 2), size=resolution, mode="bilinear"
+                    ).permute(0, 2, 3, 1)
+                    pos_embedding = pos_embedding.to(orig_dtype)
+
                 stage.pos_embed = nn.Parameter(pos_embedding)
                 stage.resolution = resolution
                 resolution = (resolution[0] // 2, resolution[1] // 2)
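
In the iformer hunk the embedding is cast to float32 before interpolation and back to its original dtype afterwards, which keeps the resize supported and numerically stable when the parameter is stored in reduced precision. A short sketch of the same round trip on an assumed channels-last (1, H, W, C) layout:

    import torch
    import torch.nn.functional as F

    pos = torch.randn(1, 8, 8, 64, dtype=torch.float16)  # stored in half precision
    orig_dtype = pos.dtype
    resized = F.interpolate(pos.float().permute(0, 3, 1, 2), size=(16, 16), mode="bilinear")
    resized = resized.permute(0, 2, 3, 1).to(orig_dtype)
    print(resized.shape, resized.dtype)  # torch.Size([1, 16, 16, 64]) torch.float16
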
birder/net/levit.py
CHANGED
@@ -454,42 +454,54 @@ class LeViT(BaseNet):
                 # Update Subsample resolution
                 m.q[0].resolution = resolution
 
-
-
-
-
-
-                # Rebuild attention bias indices
-                k_pos = torch.stack(
-                    torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
-                ).flatten(1)
-                q_pos = torch.stack(
-                    torch.meshgrid(
-                        torch.arange(0, resolution[0], step=m.stride),
-                        torch.arange(0, resolution[1], step=m.stride),
-                        indexing="ij",
+                with torch.no_grad():
+                    # Interpolate attention biases
+                    m.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
                     )
-
-
-
-
+
+                    # Rebuild attention bias indices
+                    device = m.attention_biases.device
+                    k_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(resolution[0], device=device),
+                            torch.arange(resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    q_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(0, resolution[0], step=m.stride, device=device),
+                            torch.arange(0, resolution[1], step=m.stride, device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                    m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
 
                 old_resolution = ((old_resolution[0] - 1) // 2 + 1, (old_resolution[1] - 1) // 2 + 1)
                 resolution = ((resolution[0] - 1) // 2 + 1, (resolution[1] - 1) // 2 + 1)
 
             elif isinstance(m, Attention):
-
-
-
-
-
-
-
-
-
-
-
-
+                with torch.no_grad():
+                    # Interpolate attention biases
+                    m.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
+                    )
+
+                    # Rebuild attention bias indices
+                    device = m.attention_biases.device
+                    pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(resolution[0], device=device),
+                            torch.arange(resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                    m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
 
 
 registry.register_model_config(
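
For the LeViT Subsample branch the rebuilt index table is rectangular: query positions come from a strided grid while key positions cover the full resolution. A compact sketch of that part of the computation with illustrative values:

    import torch

    resolution, stride = (4, 4), 2
    k_pos = torch.stack(
        torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
    ).flatten(1)
    q_pos = torch.stack(
        torch.meshgrid(
            torch.arange(0, resolution[0], step=stride),
            torch.arange(0, resolution[1], step=stride),
            indexing="ij",
        )
    ).flatten(1)
    rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
    idxs = (rel_pos[0] * resolution[1]) + rel_pos[1]
    print(idxs.shape)  # torch.Size([4, 16]) -- one row per query, one column per key
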
birder/net/lit_v1_tiny.py
CHANGED
@@ -340,3 +340,18 @@ registry.register_model_config(
         "drop_path_rate": 0.1,
     },
 )
+
+registry.register_weights(
+    "lit_v1_t_il-common",
+    {
+        "description": "LIT v1 Tiny model trained on the il-common dataset",
+        "resolution": (256, 256),
+        "formats": {
+            "pt": {
+                "file_size": 75.2,
+                "sha256": "93813b2716eb9f33e06dc15ab2ba335c6d219354d2983bbc4f834f8f4e688e5c",
+            }
+        },
+        "net": {"network": "lit_v1_t", "tag": "il-common"},
+    },
+)
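
The new weights entry records a file size and SHA-256 digest for the released checkpoint. A generic way to verify a downloaded file against that digest (plain Python; the local path is hypothetical):

    import hashlib
    from pathlib import Path


    def sha256_of(path: Path) -> str:
        digest = hashlib.sha256()
        with path.open("rb") as handle:
            for chunk in iter(lambda: handle.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()


    expected = "93813b2716eb9f33e06dc15ab2ba335c6d219354d2983bbc4f834f8f4e688e5c"
    # Hypothetical local path; point it at wherever the .pt file was downloaded
    # assert sha256_of(Path("models/lit_v1_t_il-common.pt")) == expected
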
birder/net/maxvit.py
CHANGED
@@ -52,8 +52,10 @@ def _make_block_input_shapes(input_size: tuple[int, int], n_blocks: int) -> list
     return shapes
 
 
-def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
-    coords = torch.stack(
+def _get_relative_position_index(height: int, width: int, device: torch.device | None = None) -> torch.Tensor:
+    coords = torch.stack(
+        torch.meshgrid([torch.arange(height, device=device), torch.arange(width, device=device)], indexing="ij")
+    )
     coords_flat = torch.flatten(coords, 1)
     relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()
@@ -152,7 +154,9 @@ class RelativePositionalMultiHeadAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             torch.empty(((2 * self.size[0] - 1) * (2 * self.size[1] - 1), self.n_heads), dtype=torch.float32),
         )
-        self.relative_position_index = nn.Buffer(
+        self.relative_position_index = nn.Buffer(
+            _get_relative_position_index(self.size[0], self.size[1], device=self.relative_position_bias_table.device)
+        )
 
         # Initialize with truncated normal the bias
         nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
@@ -682,60 +686,68 @@ class MaxViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
                 old_attn_size = attn.size
                 attn.size = self.partition_size
                 attn.max_seq_len = self.partition_size[0] * self.partition_size[1]
-
-
-
-
-
-
-
-                rel_pos_bias = attn.relative_position_bias_table
-                rel_pos_bias = rel_pos_bias.detach()
-
-                num_attn_heads = rel_pos_bias.size(1)
-                src_size = (2 * old_attn_size[0] - 1, 2 * old_attn_size[1] - 1)
-
-                def _calc(src: int, dst: int) -> list[float]:
-                    (left, right) = 1.01, 1.5
-                    while right - left > 1e-6:
-                        q = (left + right) / 2.0
-                        gp = (1.0 - q ** (src // 2)) / (1.0 - q)  # Geometric progression
-                        if gp > dst // 2:
-                            right = q
-
-                        else:
-                            left = q
-
-                    dis = []
-                    cur = 1.0
-                    for i in range(src // 2):
-                        dis.append(cur)
-                        cur += q ** (i + 1)
-
-                    r_ids = [-_ for _ in reversed(dis)]
-                    return r_ids + [0] + dis
-
-                y = _calc(src_size[0], dst_size[0])
-                x = _calc(src_size[1], dst_size[1])
-
-                ty = dst_size[0] // 2.0
-                tx = dst_size[1] // 2.0
-                dy = torch.arange(-ty, ty + 0.1, 1.0)
-                dx = torch.arange(-tx, tx + 0.1, 1.0)
-                dxy = torch.meshgrid(dx, dy, indexing="ij")
-
-                all_rel_pos_bias = []
-                for i in range(num_attn_heads):
-                    z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
-                    rgi = interpolate.RegularGridInterpolator(
-                        (x, y), z.numpy().T, method="cubic", bounds_error=False, fill_value=None
+                with torch.no_grad():
+                    attn.relative_position_index = nn.Buffer(
+                        _get_relative_position_index(
+                            attn.size[0],
+                            attn.size[1],
+                            device=attn.relative_position_bias_table.device,
+                        )
                     )
-                    r = torch.Tensor(rgi(dxy)).T.contiguous().to(rel_pos_bias.device)
-
-                    r = r.view(-1, 1)
-                    all_rel_pos_bias.append(r)
 
-
+                    # Interpolate relative_position_bias_table, adapted from
+                    # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/pos_embed_rel.py
+                    dst_size = (2 * attn.size[0] - 1, 2 * attn.size[1] - 1)
+                    rel_pos_bias = attn.relative_position_bias_table.detach()
+                    rel_pos_device = rel_pos_bias.device
+                    rel_pos_bias = rel_pos_bias.float().cpu()
+
+                    num_attn_heads = rel_pos_bias.size(1)
+                    src_size = (2 * old_attn_size[0] - 1, 2 * old_attn_size[1] - 1)
+
+                    def _calc(src: int, dst: int) -> list[float]:
+                        (left, right) = 1.01, 1.5
+                        while right - left > 1e-6:
+                            q = (left + right) / 2.0
+                            gp = (1.0 - q ** (src // 2)) / (1.0 - q)  # Geometric progression
+                            if gp > dst // 2:
+                                right = q
+
+                            else:
+                                left = q
+
+                        dis = []
+                        cur = 1.0
+                        for i in range(src // 2):
+                            dis.append(cur)
+                            cur += q ** (i + 1)
+
+                        r_ids = [-_ for _ in reversed(dis)]
+                        return r_ids + [0] + dis
+
+                    y = _calc(src_size[0], dst_size[0])
+                    x = _calc(src_size[1], dst_size[1])
+
+                    ty = dst_size[0] // 2.0
+                    tx = dst_size[1] // 2.0
+                    dy = torch.arange(-ty, ty + 0.1, 1.0)
+                    dx = torch.arange(-tx, tx + 0.1, 1.0)
+                    dxy = torch.meshgrid(dx, dy, indexing="ij")
+
+                    all_rel_pos_bias = []
+                    for i in range(num_attn_heads):
+                        z = rel_pos_bias[:, i].view(src_size[0], src_size[1])
+                        rgi = interpolate.RegularGridInterpolator(
+                            (x, y), z.numpy().T, method="cubic", bounds_error=False, fill_value=None
+                        )
+                        r = torch.tensor(
+                            rgi(dxy), device=rel_pos_device, dtype=rel_pos_bias.dtype
+                        ).T.contiguous()
+
+                        r = r.view(-1, 1)
+                        all_rel_pos_bias.append(r)
+
+                    rel_pos_bias = torch.concat(all_rel_pos_bias, dim=-1)
                 attn.relative_position_bias_table = nn.Parameter(rel_pos_bias)
 
                 new_grid_size = m.grid_size
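
The new device parameter on _get_relative_position_index lets both the constructor and adjust_size build the index buffer directly on the device of the bias table. A self-contained sketch of such a helper; the offset-normalisation tail (the += and *= lines) sits below the hunk shown above and is assumed here from the standard Swin/MaxViT formulation:

    import torch


    def get_relative_position_index(height: int, width: int, device: torch.device | None = None) -> torch.Tensor:
        coords = torch.stack(
            torch.meshgrid([torch.arange(height, device=device), torch.arange(width, device=device)], indexing="ij")
        )
        coords_flat = torch.flatten(coords, 1)
        relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += height - 1  # assumed: shift offsets to start at zero
        relative_coords[:, :, 1] += width - 1
        relative_coords[:, :, 0] *= 2 * width - 1
        return relative_coords.sum(-1)


    idx = get_relative_position_index(7, 7, device=torch.device("cpu"))
    print(idx.shape)  # torch.Size([49, 49])
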
birder/net/mobileone.py
CHANGED
@@ -380,6 +380,7 @@ class MobileOne(DetectorBackbone):
         x = self.forward_features(x)
         return self.features(x)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         for module in self.modules():
             if hasattr(module, "reparameterize") is True:
birder/net/mvit_v2.py
CHANGED
@@ -638,18 +638,19 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             rel_sp_dim_h = 2 * max(q_size_h, kv_size_h) - 1
             rel_sp_dim_w = 2 * max(q_size_w, kv_size_w) - 1
 
-
-
-
-
-
-
-
-
-
-
-
-
+            with torch.no_grad():
+                rel_pos_h = m.attn.rel_pos_h
+                rel_pos_h_resized = F.interpolate(
+                    rel_pos_h.reshape(1, rel_pos_h.shape[0], -1).permute(0, 2, 1),
+                    size=rel_sp_dim_h,
+                    mode="linear",
+                )
+                rel_pos_w = m.attn.rel_pos_w
+                rel_pos_w_resized = F.interpolate(
+                    rel_pos_w.reshape(1, rel_pos_w.shape[0], -1).permute(0, 2, 1),
+                    size=rel_sp_dim_w,
+                    mode="linear",
+                )
 
             m.attn.rel_pos_h = nn.Parameter(rel_pos_h_resized.reshape(-1, rel_sp_dim_h).permute(1, 0))
             m.attn.rel_pos_w = nn.Parameter(rel_pos_w_resized.reshape(-1, rel_sp_dim_w).permute(1, 0))
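
The MViT v2 relative-position tables are one-dimensional (one learned vector per distance), so they are resized with linear interpolation along the length axis. A small sketch of the reshape/interpolate/reshape dance with illustrative sizes:

    import torch
    import torch.nn.functional as F

    rel_pos_h = torch.randn(13, 96)  # (2 * old_size - 1, head_dim)
    rel_sp_dim_h = 27                # 2 * new_size - 1
    resized = F.interpolate(
        rel_pos_h.reshape(1, rel_pos_h.shape[0], -1).permute(0, 2, 1),  # -> (1, head_dim, length)
        size=rel_sp_dim_h,
        mode="linear",
    )
    rel_pos_h = resized.reshape(-1, rel_sp_dim_h).permute(1, 0)  # -> (new_length, head_dim)
    print(rel_pos_h.shape)  # torch.Size([27, 96])
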
birder/net/pit.py
CHANGED
@@ -258,9 +258,10 @@ class PiT(DetectorBackbone):
         height = (new_size[0] - self.patch_size[0]) // self.patch_stride[0] + 1
         width = (new_size[1] - self.patch_size[1]) // self.patch_stride[1] + 1
 
-
-            F.interpolate(self.pos_embed, (height, width), mode="bicubic")
-
+        with torch.no_grad():
+            pos_embed = F.interpolate(self.pos_embed, (height, width), mode="bicubic")
+
+        self.pos_embed = nn.Parameter(pos_embed)
 
 
 registry.register_model_config(
birder/net/pvt_v1.py
CHANGED
@@ -308,7 +308,10 @@ class PVT_v1(DetectorBackbone):
         s = (new_size[0] // 4, new_size[1] // 4)
         for m in self.body.modules():
             if isinstance(m, PyramidVisionTransformerStage):
-
+                with torch.no_grad():
+                    pos_embed = adjust_position_embedding(m.pos_embed, old_s, s, 0)
+
+                m.pos_embed = nn.Parameter(pos_embed)
                 old_s = (old_s[0] // 2, old_s[1] // 2)
                 s = (s[0] // 2, s[1] // 2)
 
birder/net/repghost.py
CHANGED
@@ -338,6 +338,7 @@ class RepGhost(DetectorBackbone):
         x = self.forward_features(x)
         return self.features(x)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         for module in self.modules():
             if hasattr(module, "reparameterize") is True: