birder 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/deepfool.py +2 -0
- birder/adversarial/simba.py +2 -0
- birder/common/fs_ops.py +2 -2
- birder/common/masking.py +13 -4
- birder/common/training_cli.py +6 -1
- birder/common/training_utils.py +4 -2
- birder/inference/classification.py +1 -1
- birder/introspection/__init__.py +2 -0
- birder/introspection/base.py +0 -7
- birder/introspection/feature_pca.py +101 -0
- birder/kernels/soft_nms/soft_nms.cpp +5 -2
- birder/model_registry/model_registry.py +3 -2
- birder/net/base.py +3 -3
- birder/net/biformer.py +2 -2
- birder/net/cas_vit.py +6 -6
- birder/net/coat.py +8 -8
- birder/net/conv2former.py +2 -2
- birder/net/convnext_v1.py +22 -2
- birder/net/convnext_v2.py +2 -2
- birder/net/crossformer.py +2 -2
- birder/net/cspnet.py +2 -2
- birder/net/cswin_transformer.py +2 -2
- birder/net/darknet.py +2 -2
- birder/net/davit.py +2 -2
- birder/net/deit.py +3 -3
- birder/net/deit3.py +3 -3
- birder/net/densenet.py +2 -2
- birder/net/detection/deformable_detr.py +2 -2
- birder/net/detection/detr.py +2 -2
- birder/net/detection/efficientdet.py +2 -2
- birder/net/detection/faster_rcnn.py +2 -2
- birder/net/detection/fcos.py +2 -2
- birder/net/detection/retinanet.py +2 -2
- birder/net/detection/rt_detr_v1.py +4 -4
- birder/net/detection/ssd.py +2 -2
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +2 -2
- birder/net/detection/yolo_v3.py +2 -2
- birder/net/detection/yolo_v4.py +2 -2
- birder/net/edgenext.py +2 -2
- birder/net/edgevit.py +1 -1
- birder/net/efficientformer_v1.py +4 -4
- birder/net/efficientformer_v2.py +6 -6
- birder/net/efficientnet_lite.py +2 -2
- birder/net/efficientnet_v1.py +2 -2
- birder/net/efficientnet_v2.py +2 -2
- birder/net/efficientvim.py +3 -3
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +2 -2
- birder/net/fasternet.py +2 -2
- birder/net/fastvit.py +2 -3
- birder/net/flexivit.py +11 -6
- birder/net/focalnet.py +2 -3
- birder/net/gc_vit.py +17 -2
- birder/net/ghostnet_v1.py +2 -2
- birder/net/ghostnet_v2.py +2 -2
- birder/net/groupmixformer.py +2 -2
- birder/net/hgnet_v1.py +2 -2
- birder/net/hgnet_v2.py +2 -2
- birder/net/hiera.py +2 -2
- birder/net/hieradet.py +2 -2
- birder/net/hornet.py +2 -2
- birder/net/iformer.py +2 -2
- birder/net/inception_next.py +2 -2
- birder/net/inception_resnet_v1.py +2 -2
- birder/net/inception_resnet_v2.py +2 -2
- birder/net/inception_v3.py +2 -2
- birder/net/inception_v4.py +2 -2
- birder/net/levit.py +4 -4
- birder/net/lit_v1.py +2 -2
- birder/net/lit_v1_tiny.py +2 -2
- birder/net/lit_v2.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/metaformer.py +2 -2
- birder/net/mnasnet.py +2 -2
- birder/net/mobilenet_v1.py +2 -2
- birder/net/mobilenet_v2.py +2 -2
- birder/net/mobilenet_v3_large.py +2 -2
- birder/net/mobilenet_v4.py +2 -2
- birder/net/mobilenet_v4_hybrid.py +2 -2
- birder/net/mobileone.py +2 -2
- birder/net/mobilevit_v2.py +2 -2
- birder/net/moganet.py +2 -2
- birder/net/mvit_v2.py +2 -2
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +2 -2
- birder/net/pit.py +6 -6
- birder/net/pvt_v1.py +2 -2
- birder/net/pvt_v2.py +2 -2
- birder/net/rdnet.py +2 -2
- birder/net/regionvit.py +6 -6
- birder/net/regnet.py +2 -2
- birder/net/regnet_z.py +2 -2
- birder/net/repghost.py +2 -2
- birder/net/repvgg.py +2 -2
- birder/net/repvit.py +6 -6
- birder/net/resnest.py +2 -2
- birder/net/resnet_v1.py +2 -2
- birder/net/resnet_v2.py +2 -2
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +3 -3
- birder/net/rope_flexivit.py +13 -6
- birder/net/rope_vit.py +69 -10
- birder/net/shufflenet_v1.py +2 -2
- birder/net/shufflenet_v2.py +2 -2
- birder/net/smt.py +1 -2
- birder/net/squeezenext.py +2 -2
- birder/net/ssl/byol.py +3 -2
- birder/net/ssl/capi.py +156 -11
- birder/net/ssl/data2vec.py +3 -1
- birder/net/ssl/data2vec2.py +3 -1
- birder/net/ssl/dino_v1.py +1 -1
- birder/net/ssl/dino_v2.py +140 -18
- birder/net/ssl/franca.py +145 -13
- birder/net/ssl/ibot.py +1 -2
- birder/net/ssl/mmcr.py +3 -1
- birder/net/starnet.py +2 -2
- birder/net/swiftformer.py +6 -6
- birder/net/swin_transformer_v1.py +2 -2
- birder/net/swin_transformer_v2.py +2 -2
- birder/net/tiny_vit.py +2 -2
- birder/net/transnext.py +1 -1
- birder/net/uniformer.py +1 -1
- birder/net/van.py +1 -1
- birder/net/vgg.py +1 -1
- birder/net/vgg_reduced.py +1 -1
- birder/net/vit.py +172 -8
- birder/net/vit_parallel.py +5 -5
- birder/net/vit_sam.py +3 -3
- birder/net/vovnet_v1.py +2 -2
- birder/net/vovnet_v2.py +2 -2
- birder/net/wide_resnet.py +2 -2
- birder/net/xception.py +2 -2
- birder/net/xcit.py +2 -2
- birder/results/detection.py +104 -0
- birder/results/gui.py +10 -8
- birder/scripts/benchmark.py +1 -1
- birder/scripts/train.py +13 -18
- birder/scripts/train_barlow_twins.py +10 -14
- birder/scripts/train_byol.py +11 -15
- birder/scripts/train_capi.py +38 -17
- birder/scripts/train_data2vec.py +11 -15
- birder/scripts/train_data2vec2.py +13 -17
- birder/scripts/train_detection.py +11 -14
- birder/scripts/train_dino_v1.py +20 -22
- birder/scripts/train_dino_v2.py +126 -63
- birder/scripts/train_dino_v2_dist.py +127 -64
- birder/scripts/train_franca.py +49 -34
- birder/scripts/train_i_jepa.py +11 -14
- birder/scripts/train_ibot.py +16 -18
- birder/scripts/train_kd.py +14 -20
- birder/scripts/train_mim.py +10 -13
- birder/scripts/train_mmcr.py +11 -15
- birder/scripts/train_rotnet.py +12 -16
- birder/scripts/train_simclr.py +10 -14
- birder/scripts/train_vicreg.py +10 -14
- birder/tools/avg_model.py +24 -8
- birder/tools/det_results.py +91 -0
- birder/tools/introspection.py +35 -9
- birder/tools/results.py +11 -7
- birder/tools/show_iterator.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
- birder-0.3.2.dist-info/RECORD +299 -0
- birder-0.3.0.dist-info/RECORD +0 -298
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/net/regionvit.py
CHANGED
@@ -464,14 +464,14 @@ class RegionViT(DetectorBackbone):

     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
-            param.
+            param.requires_grad_(False)

         if freeze_classifier is False:
             for param in self.classifier.parameters():
-                param.
+                param.requires_grad_(True)
         if unfreeze_features is True:
             for param in self.norm.parameters():
-                param.
+                param.requires_grad_(True)

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         o_x = x
@@ -488,16 +488,16 @@ class RegionViT(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.patch_embed.parameters():
-            param.
+            param.requires_grad_(False)
         for param in self.cls_token.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         o_x = x
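Across this diff and the many similar backbone diffs that follow, the freeze/unfreeze helpers now call PyTorch's in-place setter Tensor.requires_grad_() instead of assigning the attribute directly (the right-hand side of each removed line is truncated in this diff view). A minimal standalone sketch of the same pattern, using a toy nn.Sequential rather than a birder backbone:

from torch import nn

# Toy stand-in for a backbone: freeze everything, then selectively unfreeze the "classifier".
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

for param in model.parameters():
    param.requires_grad_(False)  # in-place setter, same effect as `param.requires_grad = False`

for param in model[-1].parameters():  # unfreeze only the last layer
    param.requires_grad_(True)

print([name for name, p in model.named_parameters() if p.requires_grad])  # ['2.weight', '2.bias']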
birder/net/regnet.py
CHANGED
@@ -364,14 +364,14 @@ class RegNet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def masked_encoding_retention(
         self,
birder/net/regnet_z.py
CHANGED
@@ -210,14 +210,14 @@ class RegNet_Z(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def masked_encoding_retention(
         self,
birder/net/repghost.py
CHANGED
@@ -321,14 +321,14 @@ class RepGhost(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/repvgg.py
CHANGED
@@ -302,14 +302,14 @@ class RepVgg(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/repvit.py
CHANGED
@@ -399,18 +399,18 @@ class RepViT(DetectorBackbone):

     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
-            param.
+            param.requires_grad_(False)

         if freeze_classifier is False:
             for param in self.classifier.parameters():
-                param.
+                param.requires_grad_(True)

             for param in self.dist_classifier.parameters():
-                param.
+                param.requires_grad_(True)

         if unfreeze_features is True:
             for param in self.features.parameters():
-                param.
+                param.requires_grad_(True)

     def transform_to_backbone(self) -> None:
         self.features = nn.Identity()
@@ -430,14 +430,14 @@ class RepViT(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/resnest.py
CHANGED
@@ -271,14 +271,14 @@ class ResNeSt(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/resnet_v1.py
CHANGED
@@ -192,14 +192,14 @@ class ResNet_v1(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/resnet_v2.py
CHANGED
@@ -178,14 +178,14 @@ class ResNet_v2(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/resnext.py
CHANGED
@@ -216,14 +216,14 @@ class ResNeXt(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/rope_deit3.py
CHANGED
@@ -245,16 +245,16 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
-            param.
+            param.requires_grad_(False)

-        self.pos_embedding.
+        self.pos_embedding.requires_grad_(False)

         for idx, module in enumerate(self.encoder.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)
birder/net/rope_flexivit.py
CHANGED
@@ -69,6 +69,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -118,6 +120,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
         self.norm_layer = norm_layer
+        self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
         self.rope_rot_type = rope_rot_type
@@ -190,6 +193,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
@@ -231,6 +236,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             rope_temperature=rope_temperature,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
             rope_rot_type=rope_rot_type,
         )
@@ -285,16 +291,16 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,

     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
-            param.
+            param.requires_grad_(False)

         if freeze_classifier is False:
             for param in self.classifier.parameters():
-                param.
+                param.requires_grad_(True)

         if unfreeze_features is True:
             if self.attn_pool is not None:
                 for param in self.attn_pool.parameters():
-                    param.
+                    param.requires_grad_(True)

     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)
@@ -332,16 +338,16 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
-            param.
+            param.requires_grad_(False)

-        self.pos_embedding.
+        self.pos_embedding.requires_grad_(False)

         for idx, module in enumerate(self.encoder.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     # pylint: disable=too-many-branches
     def masked_encoding_omission(
@@ -588,6 +594,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             rope_temperature=self.rope_temperature,
             layer_scale_init_value=self.layer_scale_init_value,
             norm_layer=self.norm_layer,
+            norm_layer_eps=self.norm_layer_eps,
             mlp_layer=self.mlp_layer,
             rope_rot_type=self.rope_rot_type,
         )
birder/net/rope_vit.py
CHANGED
@@ -150,6 +150,10 @@ class RoPEAttention(nn.Module):
         attn_drop: float,
         proj_drop: float,
         num_special_tokens: int,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        norm_layer_eps: float = 1e-6,
         rope_rot_type: str = "standard",
     ) -> None:
         super().__init__()
@@ -167,7 +171,14 @@ class RoPEAttention(nn.Module):
         else:
             raise ValueError(f"Unknown rope_rot_type, got '{rope_rot_type}'")

-        self.qkv = nn.Linear(dim, dim * 3)
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        if qk_norm is True:
+            self.q_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
+            self.k_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
+        else:
+            self.q_norm = nn.Identity()
+            self.k_norm = nn.Identity()
+
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -176,6 +187,8 @@ class RoPEAttention(nn.Module):
         (B, N, C) = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         (q, k, v) = qkv.unbind(0)
+        q = self.q_norm(q)
+        k = self.k_norm(k)

         n = self.num_special_tokens
         q = torch.concat([q[:, :, :n, :], self.apply_rot_fn(q[:, :, n:, :], rope)], dim=2)
@@ -207,6 +220,8 @@ class EncoderBlock(nn.Module):
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
         norm_layer_eps: float = 1e-6,
         mlp_layer: Callable[..., nn.Module] = FFN,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
         rope_rot_type: str = "standard",
     ) -> None:
         super().__init__()
@@ -222,6 +237,10 @@ class EncoderBlock(nn.Module):
             attn_drop=attention_dropout,
             proj_drop=dropout,
             num_special_tokens=num_special_tokens,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             rope_rot_type=rope_rot_type,
         )
         if layer_scale_init_value is not None:
@@ -249,7 +268,6 @@ class EncoderBlock(nn.Module):


 class Encoder(nn.Module):
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def __init__(
         self,
         num_layers: int,
@@ -261,6 +279,8 @@ class Encoder(nn.Module):
         attention_dropout: float,
         dpr: list[float],
         pre_norm: bool = False,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
         activation_layer: Callable[..., nn.Module] = nn.GELU,
         layer_scale_init_value: Optional[float] = None,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
@@ -293,6 +313,8 @@ class Encoder(nn.Module):
                     norm_layer=norm_layer,
                     norm_layer_eps=norm_layer_eps,
                     mlp_layer=mlp_layer,
+                    qkv_bias=qkv_bias,
+                    qk_norm=qk_norm,
                     rope_rot_type=rope_rot_type,
                 )
             )
@@ -331,6 +353,7 @@ class MAEDecoderBlock(nn.Module):
         rope_temperature: float,
         layer_scale_init_value: Optional[float] = None,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        norm_layer_eps: float = 1e-6,
         mlp_layer: Callable[..., nn.Module] = FFN,
         rope_rot_type: str = "standard",
     ) -> None:
@@ -346,7 +369,7 @@ class MAEDecoderBlock(nn.Module):
         )

         # Attention block
-        self.norm1 = norm_layer(hidden_dim, eps=
+        self.norm1 = norm_layer(hidden_dim, eps=norm_layer_eps)
         self.attn = RoPEAttention(
             hidden_dim,
             num_heads,
@@ -361,7 +384,7 @@ class MAEDecoderBlock(nn.Module):
             self.layer_scale_1 = nn.Identity()

         # MLP block
-        self.norm2 = norm_layer(hidden_dim, eps=
+        self.norm2 = norm_layer(hidden_dim, eps=norm_layer_eps)
         self.mlp = mlp_layer(hidden_dim, mlp_dim, act_layer=activation_layer, dropout=0.0)
         if layer_scale_init_value is not None:
             self.layer_scale_2 = LayerScale(hidden_dim, layer_scale_init_value)
@@ -403,6 +426,8 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -450,6 +475,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
         self.norm_layer = norm_layer
+        self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
         self.rope_rot_type = rope_rot_type
@@ -521,6 +547,8 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
@@ -562,6 +590,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             rope_temperature=rope_temperature,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
             rope_rot_type=rope_rot_type,
         )
@@ -614,16 +643,16 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask

     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
-            param.
+            param.requires_grad_(False)

         if freeze_classifier is False:
             for param in self.classifier.parameters():
-                param.
+                param.requires_grad_(True)

         if unfreeze_features is True:
             if self.attn_pool is not None:
                 for param in self.attn_pool.parameters():
-                    param.
+                    param.requires_grad_(True)

     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)
@@ -661,16 +690,16 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
-            param.
+            param.requires_grad_(False)

-        self.pos_embedding.
+        self.pos_embedding.requires_grad_(False)

         for idx, module in enumerate(self.encoder.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def masked_encoding_omission(
         self,
@@ -904,6 +933,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             rope_temperature=self.rope_temperature,
             layer_scale_init_value=self.layer_scale_init_value,
             norm_layer=self.norm_layer,
+            norm_layer_eps=self.norm_layer_eps,
             mlp_layer=self.mlp_layer,
             rope_rot_type=self.rope_rot_type,
         )
@@ -931,6 +961,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
 # - rms : RMSNorm (instead of LayerNorm)
 # - pn : Pre-Norm (layer norm before the encoder) - implies different norm eps
 # - npn : No Post Norm (disables post-normalization layer)
+# - qkn : QK Norm
 #
 # Feed-Forward Network:
 # - swiglu : SwiGLU FFN layer type (instead of standard FFN)
@@ -1068,6 +1099,20 @@ registry.register_model_config(
         "drop_path_rate": 0.1,
     },
 )
+registry.register_model_config(
+    "rope_vit_b16_qkn_ls",
+    RoPE_ViT,
+    config={
+        "patch_size": 16,
+        "num_layers": 12,
+        "num_heads": 12,
+        "hidden_dim": 768,
+        "mlp_dim": 3072,
+        "layer_scale_init_value": 1e-5,
+        "qk_norm": True,
+        "drop_path_rate": 0.1,
+    },
+)
 registry.register_model_config(
     "rope_i_vit_b16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
     RoPE_ViT,
@@ -1310,6 +1355,20 @@ registry.register_model_config(
         "drop_path_rate": 0.0,
     },
 )
+registry.register_model_config(
+    "rope_vit_reg4_m14_avg",
+    RoPE_ViT,
+    config={
+        "patch_size": 14,
+        "num_layers": 12,
+        "num_heads": 8,
+        "hidden_dim": 512,
+        "mlp_dim": 2048,
+        "num_reg_tokens": 4,
+        "class_token": False,
+        "drop_path_rate": 0.0,
+    },
+)
 registry.register_model_config(
     "rope_vit_reg4_b32",
     RoPE_ViT,
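The rope_vit.py changes thread two new options, qkv_bias and qk_norm, from the model config down through Encoder and EncoderBlock into RoPEAttention, which normalizes the per-head query and key tensors right after the qkv projection. A rough standalone sketch of that QK-Norm step (the tensor sizes, the plain LayerNorm choice and the SDPA call are illustrative assumptions, and the rotary embedding that RoPEAttention also applies is omitted):

import torch
import torch.nn.functional as F
from torch import nn

B, H, N, D = 2, 12, 197, 64  # batch, heads, tokens, head dim (illustrative sizes)

qkv = nn.Linear(H * D, H * D * 3, bias=True)  # qkv_bias=True
q_norm = nn.LayerNorm(D, eps=1e-6)            # qk_norm=True: normalization over the head dimension
k_norm = nn.LayerNorm(D, eps=1e-6)

x = torch.randn(B, N, H * D)
(q, k, v) = qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4).unbind(0)
q = q_norm(q)  # normalize queries per head before attention
k = k_norm(k)  # normalize keys per head before attention

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 12, 197, 64])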
birder/net/shufflenet_v1.py
CHANGED
@@ -220,14 +220,14 @@ class ShuffleNet_v1(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/shufflenet_v2.py
CHANGED
@@ -166,14 +166,14 @@ class ShuffleNet_v2(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/smt.py
CHANGED
@@ -275,7 +275,6 @@ class Stem(nn.Module):


 class SMTStage(nn.Module):
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def __init__(
         self,
         dim: int,
@@ -429,7 +428,7 @@ class SMT(DetectorBackbone):
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         return self.body(x)
birder/net/squeezenext.py
CHANGED
@@ -177,14 +177,14 @@ class SqueezeNext(DetectorBackbone):

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.stem.parameters():
-            param.
+            param.requires_grad_(False)

         for idx, module in enumerate(self.body.children()):
             if idx >= up_to_stage:
                 break

             for param in module.parameters():
-                param.
+                param.requires_grad_(False)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/ssl/byol.py
CHANGED
@@ -82,8 +82,9 @@ class BYOL(SSLBaseNet):
         online_predictions = self.online_predictor(projection)
         (online_pred_one, online_pred_two) = online_predictions.chunk(2, dim=0)

-
-
+        with torch.no_grad():
+            target_projections = self.target_encoder(x)
+            (target_proj_one, target_proj_two) = target_projections.chunk(2, dim=0)

         loss_one = loss_fn(online_pred_one, target_proj_two.detach())
         loss_two = loss_fn(online_pred_two, target_proj_one.detach())
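The BYOL change above computes the target-encoder projections inside torch.no_grad(), so no autograd graph is built for the momentum branch. A generic sketch of that pattern (the encoders, predictor and loss below are placeholders, not birder's classes):

import torch
import torch.nn.functional as F
from torch import nn

online_encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
online_predictor = nn.Linear(16, 16)
target_encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))  # EMA copy in practice


def loss_fn(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Negative cosine similarity, the usual BYOL-style objective
    return 2.0 - 2.0 * F.cosine_similarity(p, z, dim=-1).mean()


x = torch.randn(8, 32)  # a batch of (augmented) views
online_pred = online_predictor(online_encoder(x))

with torch.no_grad():  # target branch: no gradients flow through the momentum encoder
    target_proj = target_encoder(x)

loss = loss_fn(online_pred, target_proj.detach())
loss.backward()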