birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +13 -13
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +6 -6
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +4 -4
- birder/layers/attention_pool.py +2 -2
- birder/layers/layer_scale.py +1 -1
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +4 -10
- birder/net/_rope_vit_configs.py +435 -0
- birder/net/_vit_configs.py +466 -0
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +18 -17
- birder/net/cait.py +7 -7
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +27 -27
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +3 -11
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +6 -6
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +11 -11
- birder/net/deit.py +68 -29
- birder/net/deit3.py +69 -204
- birder/net/densenet.py +9 -8
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +31 -30
- birder/net/detection/detr.py +14 -11
- birder/net/detection/efficientdet.py +10 -29
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +5 -4
- birder/net/edgevit.py +13 -14
- birder/net/efficientformer_v1.py +3 -2
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +7 -7
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +3 -3
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +50 -58
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +13 -13
- birder/net/hgnet_v1.py +6 -6
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +5 -15
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +11 -23
- birder/net/metaformer.py +5 -5
- birder/net/mim/crossmae.py +6 -6
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +4 -6
- birder/net/mim/simmim.py +3 -4
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +7 -34
- birder/net/mobilevit_v2.py +6 -54
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +30 -30
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +9 -9
- birder/net/pvt_v2.py +10 -16
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +5 -35
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +4 -1
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +62 -151
- birder/net/rope_flexivit.py +46 -33
- birder/net/rope_vit.py +44 -758
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +69 -21
- birder/net/smt.py +8 -8
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +4 -4
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +13 -3
- birder/net/ssl/franca.py +28 -4
- birder/net/ssl/i_jepa.py +5 -5
- birder/net/ssl/ibot.py +1 -1
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +13 -3
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +5 -8
- birder/net/tiny_vit.py +6 -19
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/van.py +2 -2
- birder/net/vgg.py +1 -10
- birder/net/vit.py +72 -987
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +23 -48
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +16 -13
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +12 -3
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +15 -15
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- birder-0.3.3.dist-info/RECORD +0 -299
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/deit3.py
CHANGED
|
@@ -15,12 +15,19 @@ from torch import nn
|
|
|
15
15
|
|
|
16
16
|
from birder.common.masking import mask_tensor
|
|
17
17
|
from birder.model_registry import registry
|
|
18
|
+
from birder.net._vit_configs import BASE
|
|
19
|
+
from birder.net._vit_configs import HUGE
|
|
20
|
+
from birder.net._vit_configs import LARGE
|
|
21
|
+
from birder.net._vit_configs import MEDIUM
|
|
22
|
+
from birder.net._vit_configs import SMALL
|
|
23
|
+
from birder.net._vit_configs import TINY
|
|
18
24
|
from birder.net.base import DetectorBackbone
|
|
19
25
|
from birder.net.base import MaskedTokenOmissionMixin
|
|
20
26
|
from birder.net.base import MaskedTokenRetentionMixin
|
|
21
27
|
from birder.net.base import PreTrainEncoder
|
|
22
28
|
from birder.net.base import TokenOmissionResultType
|
|
23
29
|
from birder.net.base import TokenRetentionResultType
|
|
30
|
+
from birder.net.base import normalize_out_indices
|
|
24
31
|
from birder.net.vit import Encoder
|
|
25
32
|
from birder.net.vit import EncoderBlock
|
|
26
33
|
from birder.net.vit import PatchEmbed
|
|
@@ -53,6 +60,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
53
60
|
mlp_dim: int = self.config["mlp_dim"]
|
|
54
61
|
layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", 1e-5)
|
|
55
62
|
num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
|
|
63
|
+
out_indices: Optional[list[int]] = self.config.get("out_indices", None)
|
|
56
64
|
drop_path_rate: float = self.config["drop_path_rate"]
|
|
57
65
|
|
|
58
66
|
torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
|
|
@@ -64,6 +72,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
64
72
|
self.num_reg_tokens = num_reg_tokens
|
|
65
73
|
self.num_special_tokens = 1 + self.num_reg_tokens
|
|
66
74
|
self.pos_embed_special_tokens = pos_embed_special_tokens
|
|
75
|
+
self.out_indices = normalize_out_indices(out_indices, num_layers)
|
|
67
76
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
|
|
68
77
|
|
|
69
78
|
self.conv_proj = nn.Conv2d(
|
|
@@ -72,7 +81,6 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
72
81
|
kernel_size=(patch_size, patch_size),
|
|
73
82
|
stride=(patch_size, patch_size),
|
|
74
83
|
padding=(0, 0),
|
|
75
|
-
bias=True,
|
|
76
84
|
)
|
|
77
85
|
self.patch_embed = PatchEmbed()
|
|
78
86
|
|
|
@@ -106,8 +114,9 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
106
114
|
)
|
|
107
115
|
self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
|
|
108
116
|
|
|
109
|
-
self.
|
|
110
|
-
self.
|
|
117
|
+
num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
|
|
118
|
+
self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
|
|
119
|
+
self.return_channels = [hidden_dim] * num_return_stages
|
|
111
120
|
self.embedding_size = hidden_dim
|
|
112
121
|
self.classifier = self.create_classifier()
|
|
113
122
|
|
|
@@ -153,7 +162,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
153
162
|
)
|
|
154
163
|
|
|
155
164
|
def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
156
|
-
|
|
165
|
+
H, W = x.shape[-2:]
|
|
157
166
|
|
|
158
167
|
x = self.conv_proj(x)
|
|
159
168
|
x = self.patch_embed(x)
|
|
@@ -170,15 +179,20 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
170
179
|
x = x + self._get_pos_embed(H, W)
|
|
171
180
|
x = torch.concat([batch_special_tokens, x], dim=1)
|
|
172
181
|
|
|
173
|
-
|
|
174
|
-
|
|
182
|
+
if self.out_indices is None:
|
|
183
|
+
xs = [self.encoder(x)]
|
|
184
|
+
else:
|
|
185
|
+
xs = self.encoder.forward_features(x, out_indices=self.out_indices)
|
|
175
186
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
187
|
+
out: dict[str, torch.Tensor] = {}
|
|
188
|
+
for stage_name, stage_x in zip(self.return_stages, xs):
|
|
189
|
+
stage_x = stage_x[:, self.num_special_tokens :]
|
|
190
|
+
stage_x = stage_x.permute(0, 2, 1)
|
|
191
|
+
B, C, _ = stage_x.size()
|
|
192
|
+
stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
|
|
193
|
+
out[stage_name] = stage_x
|
|
180
194
|
|
|
181
|
-
return
|
|
195
|
+
return out
|
|
182
196
|
|
|
183
197
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
184
198
|
for param in self.conv_proj.parameters():
|
|
@@ -193,6 +207,10 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
193
207
|
for param in module.parameters():
|
|
194
208
|
param.requires_grad_(False)
|
|
195
209
|
|
|
210
|
+
def transform_to_backbone(self) -> None:
|
|
211
|
+
super().transform_to_backbone()
|
|
212
|
+
self.norm = nn.Identity()
|
|
213
|
+
|
|
196
214
|
def set_causal_attention(self, is_causal: bool = True) -> None:
|
|
197
215
|
self.encoder.set_causal_attention(is_causal)
|
|
198
216
|
|
|
@@ -203,7 +221,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
203
221
|
return_all_features: bool = False,
|
|
204
222
|
return_keys: Literal["all", "tokens", "embedding"] = "tokens",
|
|
205
223
|
) -> TokenOmissionResultType:
|
|
206
|
-
|
|
224
|
+
H, W = x.shape[-2:]
|
|
207
225
|
|
|
208
226
|
# Reshape and permute the input tensor
|
|
209
227
|
x = self.conv_proj(x)
|
|
@@ -266,7 +284,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
266
284
|
mask_token: Optional[torch.Tensor] = None,
|
|
267
285
|
return_keys: Literal["all", "features", "embedding"] = "features",
|
|
268
286
|
) -> TokenRetentionResultType:
|
|
269
|
-
|
|
287
|
+
H, W = x.shape[-2:]
|
|
270
288
|
|
|
271
289
|
x = self.conv_proj(x)
|
|
272
290
|
x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
|
|
@@ -296,7 +314,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
296
314
|
if return_keys in ("all", "features"):
|
|
297
315
|
features = x[:, self.num_special_tokens :]
|
|
298
316
|
features = features.permute(0, 2, 1)
|
|
299
|
-
|
|
317
|
+
B, C, _ = features.size()
|
|
300
318
|
features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
|
|
301
319
|
result["features"] = features
|
|
302
320
|
|
|
@@ -306,7 +324,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
306
324
|
return result
|
|
307
325
|
|
|
308
326
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
309
|
-
|
|
327
|
+
H, W = x.shape[-2:]
|
|
310
328
|
|
|
311
329
|
# Reshape and permute the input tensor
|
|
312
330
|
x = self.conv_proj(x)
|
|
@@ -368,279 +386,126 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
368
386
|
registry.register_model_config(
|
|
369
387
|
"deit3_t16",
|
|
370
388
|
DeiT3,
|
|
371
|
-
config={
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
"drop_path_rate": 0.0,
|
|
378
|
-
},
|
|
389
|
+
config={"patch_size": 16, **TINY},
|
|
390
|
+
)
|
|
391
|
+
registry.register_model_config(
|
|
392
|
+
"deit3_t14",
|
|
393
|
+
DeiT3,
|
|
394
|
+
config={"patch_size": 14, **TINY},
|
|
379
395
|
)
|
|
380
396
|
registry.register_model_config(
|
|
381
397
|
"deit3_s16",
|
|
382
398
|
DeiT3,
|
|
383
|
-
config={
|
|
384
|
-
"patch_size": 16,
|
|
385
|
-
"num_layers": 12,
|
|
386
|
-
"num_heads": 6,
|
|
387
|
-
"hidden_dim": 384,
|
|
388
|
-
"mlp_dim": 1536,
|
|
389
|
-
"drop_path_rate": 0.05,
|
|
390
|
-
},
|
|
399
|
+
config={"patch_size": 16, **SMALL, "drop_path_rate": 0.05},
|
|
391
400
|
)
|
|
392
401
|
registry.register_model_config(
|
|
393
402
|
"deit3_s14",
|
|
394
403
|
DeiT3,
|
|
395
|
-
config={
|
|
396
|
-
"patch_size": 14,
|
|
397
|
-
"num_layers": 12,
|
|
398
|
-
"num_heads": 6,
|
|
399
|
-
"hidden_dim": 384,
|
|
400
|
-
"mlp_dim": 1536,
|
|
401
|
-
"drop_path_rate": 0.05,
|
|
402
|
-
},
|
|
404
|
+
config={"patch_size": 14, **SMALL, "drop_path_rate": 0.05},
|
|
403
405
|
)
|
|
404
406
|
registry.register_model_config(
|
|
405
407
|
"deit3_m16",
|
|
406
408
|
DeiT3,
|
|
407
|
-
config={
|
|
408
|
-
"patch_size": 16,
|
|
409
|
-
"num_layers": 12,
|
|
410
|
-
"num_heads": 8,
|
|
411
|
-
"hidden_dim": 512,
|
|
412
|
-
"mlp_dim": 2048,
|
|
413
|
-
"drop_path_rate": 0.1,
|
|
414
|
-
},
|
|
409
|
+
config={"patch_size": 16, **MEDIUM, "drop_path_rate": 0.1},
|
|
415
410
|
)
|
|
416
411
|
registry.register_model_config(
|
|
417
412
|
"deit3_m14",
|
|
418
413
|
DeiT3,
|
|
419
|
-
config={
|
|
420
|
-
"patch_size": 14,
|
|
421
|
-
"num_layers": 12,
|
|
422
|
-
"num_heads": 8,
|
|
423
|
-
"hidden_dim": 512,
|
|
424
|
-
"mlp_dim": 2048,
|
|
425
|
-
"drop_path_rate": 0.1,
|
|
426
|
-
},
|
|
414
|
+
config={"patch_size": 14, **MEDIUM, "drop_path_rate": 0.1},
|
|
427
415
|
)
|
|
428
416
|
registry.register_model_config(
|
|
429
417
|
"deit3_b16",
|
|
430
418
|
DeiT3,
|
|
431
|
-
config={
|
|
432
|
-
"patch_size": 16,
|
|
433
|
-
"num_layers": 12,
|
|
434
|
-
"num_heads": 12,
|
|
435
|
-
"hidden_dim": 768,
|
|
436
|
-
"mlp_dim": 3072,
|
|
437
|
-
"drop_path_rate": 0.2,
|
|
438
|
-
},
|
|
419
|
+
config={"patch_size": 16, **BASE, "drop_path_rate": 0.2},
|
|
439
420
|
)
|
|
440
421
|
registry.register_model_config(
|
|
441
422
|
"deit3_b14",
|
|
442
423
|
DeiT3,
|
|
443
|
-
config={
|
|
444
|
-
"patch_size": 14,
|
|
445
|
-
"num_layers": 12,
|
|
446
|
-
"num_heads": 12,
|
|
447
|
-
"hidden_dim": 768,
|
|
448
|
-
"mlp_dim": 3072,
|
|
449
|
-
"drop_path_rate": 0.2,
|
|
450
|
-
},
|
|
424
|
+
config={"patch_size": 14, **BASE, "drop_path_rate": 0.2},
|
|
451
425
|
)
|
|
452
426
|
registry.register_model_config(
|
|
453
427
|
"deit3_l16",
|
|
454
428
|
DeiT3,
|
|
455
|
-
config={
|
|
456
|
-
"patch_size": 16,
|
|
457
|
-
"num_layers": 24,
|
|
458
|
-
"num_heads": 16,
|
|
459
|
-
"hidden_dim": 1024,
|
|
460
|
-
"mlp_dim": 4096,
|
|
461
|
-
"drop_path_rate": 0.45,
|
|
462
|
-
},
|
|
429
|
+
config={"patch_size": 16, **LARGE, "drop_path_rate": 0.45},
|
|
463
430
|
)
|
|
464
431
|
registry.register_model_config(
|
|
465
432
|
"deit3_l14",
|
|
466
433
|
DeiT3,
|
|
467
|
-
config={
|
|
468
|
-
"patch_size": 14,
|
|
469
|
-
"num_layers": 24,
|
|
470
|
-
"num_heads": 16,
|
|
471
|
-
"hidden_dim": 1024,
|
|
472
|
-
"mlp_dim": 4096,
|
|
473
|
-
"drop_path_rate": 0.45,
|
|
474
|
-
},
|
|
434
|
+
config={"patch_size": 14, **LARGE, "drop_path_rate": 0.45},
|
|
475
435
|
)
|
|
476
436
|
registry.register_model_config(
|
|
477
437
|
"deit3_h16",
|
|
478
438
|
DeiT3,
|
|
479
|
-
config={
|
|
480
|
-
"patch_size": 16,
|
|
481
|
-
"num_layers": 32,
|
|
482
|
-
"num_heads": 16,
|
|
483
|
-
"hidden_dim": 1280,
|
|
484
|
-
"mlp_dim": 5120,
|
|
485
|
-
"drop_path_rate": 0.55,
|
|
486
|
-
},
|
|
439
|
+
config={"patch_size": 16, **HUGE, "drop_path_rate": 0.55},
|
|
487
440
|
)
|
|
488
441
|
registry.register_model_config(
|
|
489
442
|
"deit3_h14",
|
|
490
443
|
DeiT3,
|
|
491
|
-
config={
|
|
492
|
-
"patch_size": 14,
|
|
493
|
-
"num_layers": 32,
|
|
494
|
-
"num_heads": 16,
|
|
495
|
-
"hidden_dim": 1280,
|
|
496
|
-
"mlp_dim": 5120,
|
|
497
|
-
"drop_path_rate": 0.55,
|
|
498
|
-
},
|
|
444
|
+
config={"patch_size": 14, **HUGE, "drop_path_rate": 0.55},
|
|
499
445
|
)
|
|
500
446
|
|
|
501
447
|
# With registers
|
|
448
|
+
####################
|
|
449
|
+
|
|
502
450
|
registry.register_model_config(
|
|
503
451
|
"deit3_reg4_t16",
|
|
504
452
|
DeiT3,
|
|
505
|
-
config={
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
"num_reg_tokens": 4,
|
|
512
|
-
"drop_path_rate": 0.0,
|
|
513
|
-
},
|
|
453
|
+
config={"patch_size": 16, **TINY, "num_reg_tokens": 4},
|
|
454
|
+
)
|
|
455
|
+
registry.register_model_config(
|
|
456
|
+
"deit3_reg4_t14",
|
|
457
|
+
DeiT3,
|
|
458
|
+
config={"patch_size": 14, **TINY, "num_reg_tokens": 4},
|
|
514
459
|
)
|
|
515
460
|
registry.register_model_config(
|
|
516
461
|
"deit3_reg4_s16",
|
|
517
462
|
DeiT3,
|
|
518
|
-
config={
|
|
519
|
-
"patch_size": 16,
|
|
520
|
-
"num_layers": 12,
|
|
521
|
-
"num_heads": 6,
|
|
522
|
-
"hidden_dim": 384,
|
|
523
|
-
"mlp_dim": 1536,
|
|
524
|
-
"num_reg_tokens": 4,
|
|
525
|
-
"drop_path_rate": 0.05,
|
|
526
|
-
},
|
|
463
|
+
config={"patch_size": 16, **SMALL, "num_reg_tokens": 4, "drop_path_rate": 0.05},
|
|
527
464
|
)
|
|
528
465
|
registry.register_model_config(
|
|
529
466
|
"deit3_reg4_s14",
|
|
530
467
|
DeiT3,
|
|
531
|
-
config={
|
|
532
|
-
"patch_size": 14,
|
|
533
|
-
"num_layers": 12,
|
|
534
|
-
"num_heads": 6,
|
|
535
|
-
"hidden_dim": 384,
|
|
536
|
-
"mlp_dim": 1536,
|
|
537
|
-
"num_reg_tokens": 4,
|
|
538
|
-
"drop_path_rate": 0.05,
|
|
539
|
-
},
|
|
468
|
+
config={"patch_size": 14, **SMALL, "num_reg_tokens": 4, "drop_path_rate": 0.05},
|
|
540
469
|
)
|
|
541
470
|
registry.register_model_config(
|
|
542
471
|
"deit3_reg4_m16",
|
|
543
472
|
DeiT3,
|
|
544
|
-
config={
|
|
545
|
-
"patch_size": 16,
|
|
546
|
-
"num_layers": 12,
|
|
547
|
-
"num_heads": 8,
|
|
548
|
-
"hidden_dim": 512,
|
|
549
|
-
"mlp_dim": 2048,
|
|
550
|
-
"num_reg_tokens": 4,
|
|
551
|
-
"drop_path_rate": 0.1,
|
|
552
|
-
},
|
|
473
|
+
config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4, "drop_path_rate": 0.1},
|
|
553
474
|
)
|
|
554
475
|
registry.register_model_config(
|
|
555
476
|
"deit3_reg4_m14",
|
|
556
477
|
DeiT3,
|
|
557
|
-
config={
|
|
558
|
-
"patch_size": 14,
|
|
559
|
-
"num_layers": 12,
|
|
560
|
-
"num_heads": 8,
|
|
561
|
-
"hidden_dim": 512,
|
|
562
|
-
"mlp_dim": 2048,
|
|
563
|
-
"num_reg_tokens": 4,
|
|
564
|
-
"drop_path_rate": 0.1,
|
|
565
|
-
},
|
|
478
|
+
config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4, "drop_path_rate": 0.1},
|
|
566
479
|
)
|
|
567
480
|
registry.register_model_config(
|
|
568
481
|
"deit3_reg4_b16",
|
|
569
482
|
DeiT3,
|
|
570
|
-
config={
|
|
571
|
-
"patch_size": 16,
|
|
572
|
-
"num_layers": 12,
|
|
573
|
-
"num_heads": 12,
|
|
574
|
-
"hidden_dim": 768,
|
|
575
|
-
"mlp_dim": 3072,
|
|
576
|
-
"num_reg_tokens": 4,
|
|
577
|
-
"drop_path_rate": 0.2,
|
|
578
|
-
},
|
|
483
|
+
config={"patch_size": 16, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.2},
|
|
579
484
|
)
|
|
580
485
|
registry.register_model_config(
|
|
581
486
|
"deit3_reg4_b14",
|
|
582
487
|
DeiT3,
|
|
583
|
-
config={
|
|
584
|
-
"patch_size": 14,
|
|
585
|
-
"num_layers": 12,
|
|
586
|
-
"num_heads": 12,
|
|
587
|
-
"hidden_dim": 768,
|
|
588
|
-
"mlp_dim": 3072,
|
|
589
|
-
"num_reg_tokens": 4,
|
|
590
|
-
"drop_path_rate": 0.2,
|
|
591
|
-
},
|
|
488
|
+
config={"patch_size": 14, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.2},
|
|
592
489
|
)
|
|
593
490
|
registry.register_model_config(
|
|
594
491
|
"deit3_reg4_l16",
|
|
595
492
|
DeiT3,
|
|
596
|
-
config={
|
|
597
|
-
"patch_size": 16,
|
|
598
|
-
"num_layers": 24,
|
|
599
|
-
"num_heads": 16,
|
|
600
|
-
"hidden_dim": 1024,
|
|
601
|
-
"mlp_dim": 4096,
|
|
602
|
-
"num_reg_tokens": 4,
|
|
603
|
-
"drop_path_rate": 0.45,
|
|
604
|
-
},
|
|
493
|
+
config={"patch_size": 16, **LARGE, "num_reg_tokens": 4, "drop_path_rate": 0.45},
|
|
605
494
|
)
|
|
606
495
|
registry.register_model_config(
|
|
607
496
|
"deit3_reg4_l14",
|
|
608
497
|
DeiT3,
|
|
609
|
-
config={
|
|
610
|
-
"patch_size": 14,
|
|
611
|
-
"num_layers": 24,
|
|
612
|
-
"num_heads": 16,
|
|
613
|
-
"hidden_dim": 1024,
|
|
614
|
-
"mlp_dim": 4096,
|
|
615
|
-
"num_reg_tokens": 4,
|
|
616
|
-
"drop_path_rate": 0.45,
|
|
617
|
-
},
|
|
498
|
+
config={"patch_size": 14, **LARGE, "num_reg_tokens": 4, "drop_path_rate": 0.45},
|
|
618
499
|
)
|
|
619
500
|
registry.register_model_config(
|
|
620
501
|
"deit3_reg4_h16",
|
|
621
502
|
DeiT3,
|
|
622
|
-
config={
|
|
623
|
-
"patch_size": 16,
|
|
624
|
-
"num_layers": 32,
|
|
625
|
-
"num_heads": 16,
|
|
626
|
-
"hidden_dim": 1280,
|
|
627
|
-
"mlp_dim": 5120,
|
|
628
|
-
"num_reg_tokens": 4,
|
|
629
|
-
"drop_path_rate": 0.55,
|
|
630
|
-
},
|
|
503
|
+
config={"patch_size": 16, **HUGE, "num_reg_tokens": 4, "drop_path_rate": 0.55},
|
|
631
504
|
)
|
|
632
505
|
registry.register_model_config(
|
|
633
506
|
"deit3_reg4_h14",
|
|
634
507
|
DeiT3,
|
|
635
|
-
config={
|
|
636
|
-
"patch_size": 14,
|
|
637
|
-
"num_layers": 32,
|
|
638
|
-
"num_heads": 16,
|
|
639
|
-
"hidden_dim": 1280,
|
|
640
|
-
"mlp_dim": 5120,
|
|
641
|
-
"num_reg_tokens": 4,
|
|
642
|
-
"drop_path_rate": 0.55,
|
|
643
|
-
},
|
|
508
|
+
config={"patch_size": 14, **HUGE, "num_reg_tokens": 4, "drop_path_rate": 0.55},
|
|
644
509
|
)
|
|
645
510
|
|
|
646
511
|
registry.register_weights(
|
|
@@ -651,7 +516,7 @@ registry.register_weights(
|
|
|
651
516
|
"formats": {
|
|
652
517
|
"pt": {
|
|
653
518
|
"file_size": 21.5,
|
|
654
|
-
"sha256": "
|
|
519
|
+
"sha256": "a04141c7f6c459ae075a48ccdee5b82d191bbaa82337673140c06ef82f0a8dc5",
|
|
655
520
|
}
|
|
656
521
|
},
|
|
657
522
|
"net": {"network": "deit3_t16", "tag": "il-common"},
|
|
@@ -665,7 +530,7 @@ registry.register_weights(
|
|
|
665
530
|
"formats": {
|
|
666
531
|
"pt": {
|
|
667
532
|
"file_size": 21.5,
|
|
668
|
-
"sha256": "
|
|
533
|
+
"sha256": "d26320462da64df6d62b307f7fb35d09c86a5f073002dfb24a51f014074e65c3",
|
|
669
534
|
}
|
|
670
535
|
},
|
|
671
536
|
"net": {"network": "deit3_reg4_t16", "tag": "il-common"},
|
birder/net/densenet.py
CHANGED
|
@@ -104,19 +104,20 @@ class DenseNet(DetectorBackbone):
|
|
|
104
104
|
num_features = num_init_features
|
|
105
105
|
stages: OrderedDict[str, nn.Module] = OrderedDict()
|
|
106
106
|
return_channels: list[int] = []
|
|
107
|
-
layers = []
|
|
108
107
|
for i, num_layers in enumerate(layer_list):
|
|
108
|
+
stage_layers = []
|
|
109
|
+
if i != 0:
|
|
110
|
+
stage_layers.append(TransitionBlock(num_features, num_features // 2))
|
|
111
|
+
num_features = num_features // 2
|
|
109
112
|
|
|
110
|
-
|
|
113
|
+
stage_layers.append(DenseBlock(num_features, num_layers=num_layers, growth_rate=growth_rate))
|
|
111
114
|
num_features = num_features + (num_layers * growth_rate)
|
|
115
|
+
if i == len(layer_list) - 1:
|
|
116
|
+
stage_layers.append(nn.BatchNorm2d(num_features))
|
|
117
|
+
stage_layers.append(nn.ReLU(inplace=True))
|
|
112
118
|
|
|
113
|
-
stages[f"stage{i+1}"] = nn.Sequential(*
|
|
119
|
+
stages[f"stage{i+1}"] = nn.Sequential(*stage_layers)
|
|
114
120
|
return_channels.append(num_features)
|
|
115
|
-
layers = []
|
|
116
|
-
|
|
117
|
-
if i != len(layer_list) - 1:
|
|
118
|
-
layers.append(TransitionBlock(num_features, num_features // 2))
|
|
119
|
-
num_features = num_features // 2
|
|
120
121
|
|
|
121
122
|
self.body = nn.Sequential(stages)
|
|
122
123
|
self.features = nn.Sequential(
|
birder/net/detection/__init__.py
CHANGED
|
@@ -3,8 +3,10 @@ from birder.net.detection.detr import DETR
|
|
|
3
3
|
from birder.net.detection.efficientdet import EfficientDet
|
|
4
4
|
from birder.net.detection.faster_rcnn import Faster_RCNN
|
|
5
5
|
from birder.net.detection.fcos import FCOS
|
|
6
|
+
from birder.net.detection.plain_detr import Plain_DETR
|
|
6
7
|
from birder.net.detection.retinanet import RetinaNet
|
|
7
8
|
from birder.net.detection.rt_detr_v1 import RT_DETR_v1
|
|
9
|
+
from birder.net.detection.rt_detr_v2 import RT_DETR_v2
|
|
8
10
|
from birder.net.detection.ssd import SSD
|
|
9
11
|
from birder.net.detection.ssdlite import SSDLite
|
|
10
12
|
from birder.net.detection.vitdet import ViTDet
|
|
@@ -19,8 +21,10 @@ __all__ = [
|
|
|
19
21
|
"EfficientDet",
|
|
20
22
|
"Faster_RCNN",
|
|
21
23
|
"FCOS",
|
|
24
|
+
"Plain_DETR",
|
|
22
25
|
"RetinaNet",
|
|
23
26
|
"RT_DETR_v1",
|
|
27
|
+
"RT_DETR_v2",
|
|
24
28
|
"SSD",
|
|
25
29
|
"SSDLite",
|
|
26
30
|
"ViTDet",
|
|
@@ -71,7 +71,7 @@ def scale_anchors(anchors: AnchorGroups, from_size: tuple[int, int], to_size: tu
|
|
|
71
71
|
|
|
72
72
|
|
|
73
73
|
def scale_anchors(anchors: AnchorLike, from_size: tuple[int, int], to_size: tuple[int, int]) -> AnchorLike:
|
|
74
|
-
|
|
74
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
75
75
|
|
|
76
76
|
if from_size == to_size:
|
|
77
77
|
# Avoid aliasing default anchors in case they are mutated later
|
|
@@ -100,7 +100,7 @@ def pixels_to_grid(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroup
|
|
|
100
100
|
|
|
101
101
|
|
|
102
102
|
def pixels_to_grid(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
|
|
103
|
-
|
|
103
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
104
104
|
if len(anchor_groups) != len(strides):
|
|
105
105
|
raise ValueError("strides must provide one value per anchor scale")
|
|
106
106
|
|
|
@@ -123,7 +123,7 @@ def grid_to_pixels(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroup
|
|
|
123
123
|
|
|
124
124
|
|
|
125
125
|
def grid_to_pixels(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
|
|
126
|
-
|
|
126
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
127
127
|
if len(anchor_groups) != len(strides):
|
|
128
128
|
raise ValueError("strides must provide one value per anchor scale")
|
|
129
129
|
|
|
@@ -187,7 +187,7 @@ def resolve_anchor_group(
|
|
|
187
187
|
preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
|
|
188
188
|
) -> AnchorGroup:
|
|
189
189
|
anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
|
|
190
|
-
|
|
190
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
191
191
|
if single is False:
|
|
192
192
|
raise ValueError("Expected a single anchor group for this model")
|
|
193
193
|
|
|
@@ -198,7 +198,7 @@ def resolve_anchor_groups(
|
|
|
198
198
|
preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
|
|
199
199
|
) -> AnchorGroups:
|
|
200
200
|
anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
|
|
201
|
-
|
|
201
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
202
202
|
if single is True:
|
|
203
203
|
raise ValueError("Expected multiple anchor groups for this model")
|
|
204
204
|
|
birder/net/detection/base.py
CHANGED
|
@@ -41,6 +41,7 @@ def get_detection_signature(input_shape: tuple[int, ...], num_outputs: int, dyna
|
|
|
41
41
|
|
|
42
42
|
class DetectionBaseNet(nn.Module):
|
|
43
43
|
default_size: tuple[int, int]
|
|
44
|
+
block_group_regex: Optional[str]
|
|
44
45
|
auto_register = False
|
|
45
46
|
scriptable = True
|
|
46
47
|
task = str(Task.OBJECT_DETECTION)
|
|
@@ -308,7 +309,7 @@ class SimpleFeaturePyramidNetwork(nn.Module):
|
|
|
308
309
|
names.append(f"stage{idx+1}")
|
|
309
310
|
|
|
310
311
|
if self.extra_blocks is not None:
|
|
311
|
-
|
|
312
|
+
results, names = self.extra_blocks(results, [x], names)
|
|
312
313
|
|
|
313
314
|
out = OrderedDict(list(zip(names, results)))
|
|
314
315
|
|
|
@@ -432,7 +433,7 @@ class BoxCoder:
|
|
|
432
433
|
ctr_x = boxes[:, 0] + 0.5 * widths
|
|
433
434
|
ctr_y = boxes[:, 1] + 0.5 * heights
|
|
434
435
|
|
|
435
|
-
|
|
436
|
+
wx, wy, ww, wh = self.weights
|
|
436
437
|
dx = rel_codes[:, 0::4] / wx
|
|
437
438
|
dy = rel_codes[:, 1::4] / wy
|
|
438
439
|
dw = rel_codes[:, 2::4] / ww
|
|
@@ -510,8 +511,8 @@ class AnchorGenerator(nn.Module):
|
|
|
510
511
|
)
|
|
511
512
|
|
|
512
513
|
for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
|
|
513
|
-
|
|
514
|
-
|
|
514
|
+
grid_height, grid_width = size
|
|
515
|
+
stride_height, stride_width = stride
|
|
515
516
|
device = base_anchors.device
|
|
516
517
|
|
|
517
518
|
# For output anchor, compute [x_center, y_center, x_center, y_center]
|
|
@@ -656,7 +657,7 @@ class Matcher(nn.Module):
|
|
|
656
657
|
# tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
|
|
657
658
|
# Each element in the first tensor is a gt index,
|
|
658
659
|
# and each element in second tensor is a prediction index
|
|
659
|
-
# Note how gt items 1, 2, 3
|
|
660
|
+
# Note how gt items 1, 2, 3 and 5 each have two ties
|
|
660
661
|
|
|
661
662
|
pred_idx_to_update = gt_pred_pairs_of_highest_quality[1]
|
|
662
663
|
matches[pred_idx_to_update] = all_matches[pred_idx_to_update]
|