birder-0.2.2-py3-none-any.whl → birder-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +24 -0
- birder/common/training_utils.py +338 -41
- birder/data/collators/detection.py +11 -3
- birder/data/dataloader/webdataset.py +12 -2
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/kernels/load_kernel.py +16 -11
- birder/kernels/soft_nms/soft_nms.cpp +17 -18
- birder/net/__init__.py +8 -0
- birder/net/cait.py +4 -3
- birder/net/convnext_v1.py +5 -0
- birder/net/crossformer.py +33 -30
- birder/net/crossvit.py +4 -3
- birder/net/deit.py +3 -3
- birder/net/deit3.py +3 -3
- birder/net/detection/deformable_detr.py +2 -5
- birder/net/detection/detr.py +2 -5
- birder/net/detection/efficientdet.py +67 -93
- birder/net/detection/fcos.py +2 -7
- birder/net/detection/retinanet.py +2 -7
- birder/net/detection/rt_detr_v1.py +2 -0
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +39 -40
- birder/net/detection/yolo_v4.py +28 -26
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/efficientformer_v1.py +15 -9
- birder/net/efficientformer_v2.py +39 -29
- birder/net/efficientvit_msft.py +9 -7
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +1 -0
- birder/net/flexivit.py +5 -4
- birder/net/gc_vit.py +671 -0
- birder/net/hiera.py +12 -9
- birder/net/hornet.py +9 -7
- birder/net/iformer.py +8 -6
- birder/net/levit.py +42 -30
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +357 -0
- birder/net/lit_v2.py +436 -0
- birder/net/maxvit.py +67 -55
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/mobileone.py +1 -0
- birder/net/mvit_v2.py +13 -12
- birder/net/pit.py +4 -3
- birder/net/pvt_v1.py +4 -1
- birder/net/repghost.py +1 -0
- birder/net/repvgg.py +1 -0
- birder/net/repvit.py +1 -0
- birder/net/resnet_v1.py +1 -1
- birder/net/resnext.py +67 -25
- birder/net/rope_deit3.py +5 -3
- birder/net/rope_flexivit.py +7 -4
- birder/net/rope_vit.py +10 -5
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +11 -8
- birder/net/swin_transformer_v1.py +71 -68
- birder/net/swin_transformer_v2.py +38 -31
- birder/net/tiny_vit.py +20 -10
- birder/net/transnext.py +38 -28
- birder/net/vit.py +5 -19
- birder/net/vit_parallel.py +5 -4
- birder/net/vit_sam.py +38 -37
- birder/net/vovnet_v1.py +15 -0
- birder/net/vovnet_v2.py +31 -1
- birder/ops/msda.py +108 -43
- birder/ops/swattention.py +124 -61
- birder/results/detection.py +4 -0
- birder/scripts/benchmark.py +110 -32
- birder/scripts/predict.py +8 -0
- birder/scripts/predict_detection.py +18 -11
- birder/scripts/train.py +48 -46
- birder/scripts/train_barlow_twins.py +44 -45
- birder/scripts/train_byol.py +44 -45
- birder/scripts/train_capi.py +50 -49
- birder/scripts/train_data2vec.py +45 -47
- birder/scripts/train_data2vec2.py +45 -47
- birder/scripts/train_detection.py +83 -50
- birder/scripts/train_dino_v1.py +60 -47
- birder/scripts/train_dino_v2.py +86 -52
- birder/scripts/train_dino_v2_dist.py +84 -50
- birder/scripts/train_franca.py +51 -52
- birder/scripts/train_i_jepa.py +45 -47
- birder/scripts/train_ibot.py +51 -53
- birder/scripts/train_kd.py +194 -76
- birder/scripts/train_mim.py +44 -45
- birder/scripts/train_mmcr.py +44 -45
- birder/scripts/train_rotnet.py +45 -46
- birder/scripts/train_simclr.py +44 -45
- birder/scripts/train_vicreg.py +44 -45
- birder/tools/auto_anchors.py +20 -1
- birder/tools/convert_model.py +18 -15
- birder/tools/det_results.py +114 -2
- birder/tools/pack.py +172 -103
- birder/tools/quantize_model.py +73 -67
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/net/detection/yolo_v4_tiny.py CHANGED

```diff
@@ -15,18 +15,13 @@ import torch
 from torch import nn
 from torchvision.ops import Conv2dNormActivation
 
+from birder.model_registry import registry
 from birder.net.base import DetectorBackbone
+from birder.net.detection.yolo_anchors import resolve_anchor_groups
 from birder.net.detection.yolo_v3 import YOLOAnchorGenerator
 from birder.net.detection.yolo_v3 import YOLOHead
-from birder.net.detection.yolo_v3 import scale_anchors
 from birder.net.detection.yolo_v4 import YOLO_v4
 
-# Default anchors from YOLO v4 Tiny (COCO)
-DEFAULT_ANCHORS = [
-    [(10.0, 14.0), (23.0, 27.0), (37.0, 58.0)],  # Medium
-    [(81.0, 82.0), (135.0, 169.0), (344.0, 319.0)],  # Large
-]
-
 # Scale factors per detection scale to eliminate grid sensitivity
 DEFAULT_SCALE_XY = [1.05, 1.05]  # [medium, large]
 
@@ -92,7 +87,6 @@ class YOLOTinyNeck(nn.Module):
 # pylint: disable=invalid-name
 class YOLO_v4_Tiny(YOLO_v4):
     default_size = (416, 416)
-    auto_register = True
 
     def __init__(
         self,
@@ -104,22 +98,26 @@ class YOLO_v4_Tiny(YOLO_v4):
         export_mode: bool = False,
     ) -> None:
         super().__init__(num_classes, backbone, config=config, size=size, export_mode=export_mode)
-        assert self.config is None, "
+        assert self.config is not None, "must set config"
 
         # self.num_classes = self.num_classes - 1 (Subtracted at parent)
 
         score_thresh = 0.05
         nms_thresh = 0.45
         detections_per_img = 300
-
-
-
-
-
-
-
-
-        self.
+        ignore_thresh = 0.7
+        noobj_coeff = 0.25
+        coord_coeff = 3.0
+        obj_coeff = 1.0
+        cls_coeff = 1.0
+        label_smoothing = 0.1
+        anchor_spec = self.config["anchors"]
+
+        self.ignore_thresh = ignore_thresh
+        self.noobj_coeff = noobj_coeff
+        self.coord_coeff = coord_coeff
+        self.obj_coeff = obj_coeff
+        self.cls_coeff = cls_coeff
         self.scale_xy = DEFAULT_SCALE_XY
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
@@ -128,12 +126,18 @@ class YOLO_v4_Tiny(YOLO_v4):
         self.backbone.return_channels = self.backbone.return_channels[-2:]
         self.backbone.return_stages = self.backbone.return_stages[-2:]
 
-        self.label_smoothing =
+        self.label_smoothing = label_smoothing
         self.smooth_positive = 1.0 - self.label_smoothing
         self.smooth_negative = self.label_smoothing / self.num_classes
 
         self.neck = YOLOTinyNeck(self.backbone.return_channels)
 
-
+        anchors = resolve_anchor_groups(
+            anchor_spec, anchor_format="pixels", model_size=self.size, model_strides=(16, 32)
+        )
+        self.anchor_generator = YOLOAnchorGenerator(anchors)
         num_anchors = self.anchor_generator.num_anchors_per_location()
         self.head = YOLOHead(self.neck.out_channels, num_anchors, self.num_classes)
+
+
+registry.register_model_config("yolo_v4_tiny", YOLO_v4_Tiny, config={"anchors": "yolo_v4_tiny"})
```
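The hard-coded `DEFAULT_ANCHORS` table and the `scale_anchors` import are replaced by the new `resolve_anchor_groups` helper (see `birder/net/detection/yolo_anchors.py`, +205 lines above), with the anchor specification now coming from the registered model config (`{"anchors": "yolo_v4_tiny"}`). The helper's body is not part of this hunk, so the sketch below is only one plausible reading of its `anchor_format="pixels"` path; the function and names in it are hypothetical, not the birder API.

```python
# Hypothetical sketch -- resolve_anchor_groups itself is not shown in this
# diff. A plausible reading of anchor_format="pixels": divide each (w, h)
# anchor by its group's detection stride so the head works in grid units.
from collections.abc import Sequence


def scale_pixel_anchors(  # hypothetical helper, not the birder API
    anchor_groups: Sequence[Sequence[tuple[float, float]]],
    model_strides: Sequence[int],
) -> list[list[tuple[float, float]]]:
    assert len(anchor_groups) == len(model_strides)
    return [
        [(w / stride, h / stride) for (w, h) in group]
        for group, stride in zip(anchor_groups, model_strides)
    ]


# The COCO anchors that used to be hard-coded as DEFAULT_ANCHORS:
pixel_anchors = [
    [(10.0, 14.0), (23.0, 27.0), (37.0, 58.0)],  # medium head, stride 16
    [(81.0, 82.0), (135.0, 169.0), (344.0, 319.0)],  # large head, stride 32
]
print(scale_pixel_anchors(pixel_anchors, model_strides=(16, 32)))
```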
birder/net/efficientformer_v1.py CHANGED

```diff
@@ -357,16 +357,22 @@ class EfficientFormer_v1(BaseNet):
         resolution = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
         for m in self.body.modules():
             if isinstance(m, Attention):
-
-
-
+                with torch.no_grad():
+                    m.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
+                    )
 
-
-
-
-
-
-
+                    device = m.attention_biases.device
+                    pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(resolution[0], device=device),
+                            torch.arange(resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                    m.attention_bias_idxs = nn.Buffer(rel_pos)
 
 
 registry.register_model_config(
```
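The rebuilt index table maps every query/key pair of the new grid to one of the learned `attention_biases` entries. A standalone check of that construction (grid size here is illustrative):

```python
import torch

resolution = (4, 3)  # (H, W) of the attention grid; sizes are illustrative
pos = torch.stack(
    torch.meshgrid(
        torch.arange(resolution[0]),
        torch.arange(resolution[1]),
        indexing="ij",
    )
).flatten(1)  # (2, N): the (row, col) coordinates of every grid cell

rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()  # (2, N, N)
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]  # (|dy|, |dx|) -> flat id

N = resolution[0] * resolution[1]
assert rel_pos.shape == (N, N)
assert torch.equal(rel_pos, rel_pos.T)  # abs() makes the table symmetric
assert int(rel_pos.max()) == N - 1  # N distinct offset ids in total
```

Indexing the bias table as `attention_biases[:, rel_pos]` then yields a `(num_heads, N, N)` tensor added to the attention logits, which is the usual LeViT-style scheme.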
birder/net/efficientformer_v2.py CHANGED

```diff
@@ -554,26 +554,30 @@ class EfficientFormer_v2(DetectorBackbone):
             attn.N = attn.resolution[0] * attn.resolution[1]
             attn.N2 = attn.resolution2[0] * attn.resolution2[1]
 
-
-
-
-
-
-            k_pos = torch.stack(
-                torch.meshgrid(
-                    torch.arange(attn.resolution[0]), torch.arange(attn.resolution[1]), indexing="ij"
+            with torch.no_grad():
+                # Interpolate attention_biases
+                attn.attention_biases = nn.Parameter(
+                    interpolate_attention_bias(attn.attention_biases, old_base, new_base)
                 )
-
-
-                torch.
-                torch.
-
-
-
-
-
-
-
+
+                device = attn.attention_biases.device
+                k_pos = torch.stack(
+                    torch.meshgrid(
+                        torch.arange(attn.resolution[0], device=device),
+                        torch.arange(attn.resolution[1], device=device),
+                        indexing="ij",
+                    )
+                ).flatten(1)
+                q_pos = torch.stack(
+                    torch.meshgrid(
+                        torch.arange(0, attn.resolution[0], step=2, device=device),
+                        torch.arange(0, attn.resolution[1], step=2, device=device),
+                        indexing="ij",
+                    )
+                ).flatten(1)
+                rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
+                rel_pos = (rel_pos[0] * attn.resolution[1]) + rel_pos[1]
+                attn.attention_bias_idxs = nn.Buffer(rel_pos.to(torch.long), persistent=False)
 
             old_base = (old_base[0] // 2, old_base[1] // 2)
             new_base = (new_base[0] // 2, new_base[1] // 2)
@@ -590,16 +594,22 @@ class EfficientFormer_v2(DetectorBackbone):
             m.token_mixer.resolution = c_new_base
             m.token_mixer.N = m.token_mixer.resolution[0] * m.token_mixer.resolution[1]
 
-
-
-
-
-
-
-
-
-
-
+            with torch.no_grad():
+                m.token_mixer.attention_biases = nn.Parameter(
+                    interpolate_attention_bias(m.token_mixer.attention_biases, c_old_base, c_new_base)
+                )
+
+                device = m.token_mixer.attention_biases.device
+                pos = torch.stack(
+                    torch.meshgrid(
+                        torch.arange(c_new_base[0], device=device),
+                        torch.arange(c_new_base[1], device=device),
+                        indexing="ij",
+                    )
+                ).flatten(1)
+                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                rel_pos = (rel_pos[0] * c_new_base[1]) + rel_pos[1]
+                m.token_mixer.attention_bias_idxs = nn.Buffer(rel_pos.to(torch.long), persistent=False)
 
 
 registry.register_model_config(
```
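The v2 variant differs from v1 in its downsampling attention: queries sit on a stride-2 grid while keys cover the full resolution, so the rebuilt `attention_bias_idxs` table is rectangular rather than square. An illustrative shape check:

```python
import torch

resolution = (6, 4)  # (H, W) of the key grid; sizes are illustrative
k_pos = torch.stack(
    torch.meshgrid(
        torch.arange(resolution[0]),
        torch.arange(resolution[1]),
        indexing="ij",
    )
).flatten(1)
q_pos = torch.stack(
    torch.meshgrid(
        torch.arange(0, resolution[0], step=2),
        torch.arange(0, resolution[1], step=2),
        indexing="ij",
    )
).flatten(1)

rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]

n_q = (resolution[0] // 2) * (resolution[1] // 2)
n_k = resolution[0] * resolution[1]
assert rel_pos.shape == (n_q, n_k)  # one learned-bias index per (query, key)
```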
birder/net/efficientvit_msft.py CHANGED

```diff
@@ -497,14 +497,16 @@ class EfficientViT_MSFT(DetectorBackbone):
 
                 idxs.append(attention_offsets[offset])
 
-
-
-
+            with torch.no_grad():
+                m.mixer.m.attn.attention_biases = nn.Parameter(
+                    interpolate_attention_bias(
+                        m.mixer.m.attn.attention_biases, old_window_resolution, window_resolution
+                    )
+                )
+                device = m.mixer.m.attn.attention_biases.device
+                m.mixer.m.attn.attention_bias_idxs = nn.Buffer(
+                    torch.tensor(idxs, device=device, dtype=torch.long).view(N, N), persistent=False
                 )
-            )
-            m.mixer.m.attn.attention_bias_idxs = nn.Buffer(
-                torch.LongTensor(idxs).view(N, N), persistent=False
-            )
 
 
 registry.register_model_config(
```
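Besides the `torch.no_grad()` guard, note the constructor change: `torch.LongTensor(...)` always allocates on the CPU, while `torch.tensor(..., device=..., dtype=torch.long)` follows the device of the parameter it accompanies, so resizing a CUDA-resident model no longer leaves the index buffer behind. Illustrative values:

```python
import torch

idxs = [0, 1, 1, 0]
biases = torch.zeros(4)  # stands in for attention_biases; may live on CUDA
table = torch.tensor(idxs, device=biases.device, dtype=torch.long).view(2, 2)
assert table.device == biases.device  # old torch.LongTensor(idxs) was CPU-only
```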
birder/net/fasternet.py CHANGED

birder/net/fastvit.py CHANGED

```diff
@@ -879,6 +879,7 @@ class FastViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         x = self.forward_features(x)
         return self.features(x)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         for module in self.modules():
             if hasattr(module, "reparameterize") is True:
```
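The new decorator matters because reparameterization mutates leaf parameters in place, which autograd forbids outside a no-grad context. A minimal stand-in (the `TwoBranch` module is hypothetical, not from birder):

```python
import torch
from torch import nn


class TwoBranch(nn.Module):  # hypothetical stand-in, not the birder API
    def __init__(self) -> None:
        super().__init__()
        self.main = nn.Linear(4, 4, bias=False)
        self.side = nn.Linear(4, 4, bias=False)

    def reparameterize(self) -> None:
        # Fold the parallel branch into the main weight, in place:
        # (W_main + W_side) @ x == main(x) + side(x)
        self.main.weight += self.side.weight


@torch.no_grad()
def reparameterize_model(model: nn.Module) -> None:
    for module in model.modules():
        if hasattr(module, "reparameterize") is True:
            module.reparameterize()


m = TwoBranch()
reparameterize_model(m)  # fine under no_grad
# Calling TwoBranch().reparameterize() outside no_grad raises:
# "a leaf Variable that requires grad is being used in an in-place operation"
```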
birder/net/flexivit.py CHANGED

```diff
@@ -519,15 +519,16 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         else:
             num_prefix_tokens = 0
 
-
-
-
+        with torch.no_grad():
+            pos_embedding = adjust_position_embedding(
+                # On rounding error see: https://github.com/facebookresearch/dino/issues/8
                 self.pos_embedding,
                 (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
                 (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
                 num_prefix_tokens,
             )
-
+
+            self.pos_embedding = nn.Parameter(pos_embedding)
 
     def adjust_patch_size(self, patch_size: int) -> None:
         if self.patch_size == patch_size:
```