birder 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. birder/common/fs_ops.py +2 -2
  2. birder/introspection/attention_rollout.py +1 -1
  3. birder/introspection/transformer_attribution.py +1 -1
  4. birder/layers/layer_scale.py +1 -1
  5. birder/net/__init__.py +2 -10
  6. birder/net/_rope_vit_configs.py +430 -0
  7. birder/net/_vit_configs.py +479 -0
  8. birder/net/biformer.py +1 -0
  9. birder/net/cait.py +5 -5
  10. birder/net/coat.py +12 -12
  11. birder/net/conv2former.py +3 -3
  12. birder/net/convmixer.py +1 -1
  13. birder/net/convnext_v1.py +1 -1
  14. birder/net/crossvit.py +5 -5
  15. birder/net/davit.py +1 -1
  16. birder/net/deit.py +12 -26
  17. birder/net/deit3.py +42 -189
  18. birder/net/densenet.py +9 -8
  19. birder/net/detection/deformable_detr.py +5 -2
  20. birder/net/detection/detr.py +5 -2
  21. birder/net/detection/efficientdet.py +1 -1
  22. birder/net/dpn.py +1 -2
  23. birder/net/edgenext.py +2 -1
  24. birder/net/edgevit.py +3 -0
  25. birder/net/efficientformer_v1.py +2 -1
  26. birder/net/efficientformer_v2.py +18 -31
  27. birder/net/efficientnet_v2.py +3 -0
  28. birder/net/efficientvit_mit.py +5 -5
  29. birder/net/fasternet.py +2 -2
  30. birder/net/flexivit.py +22 -43
  31. birder/net/groupmixformer.py +1 -1
  32. birder/net/hgnet_v1.py +5 -5
  33. birder/net/inception_next.py +1 -1
  34. birder/net/inception_resnet_v1.py +3 -3
  35. birder/net/inception_resnet_v2.py +7 -4
  36. birder/net/inception_v3.py +3 -0
  37. birder/net/inception_v4.py +3 -0
  38. birder/net/maxvit.py +1 -1
  39. birder/net/metaformer.py +3 -3
  40. birder/net/mim/crossmae.py +1 -1
  41. birder/net/mim/mae_vit.py +1 -1
  42. birder/net/mim/simmim.py +1 -1
  43. birder/net/mobilenet_v1.py +0 -9
  44. birder/net/mobilenet_v2.py +38 -44
  45. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  46. birder/net/mobilevit_v1.py +5 -32
  47. birder/net/mobilevit_v2.py +1 -45
  48. birder/net/moganet.py +8 -5
  49. birder/net/mvit_v2.py +6 -6
  50. birder/net/nfnet.py +4 -0
  51. birder/net/pit.py +1 -1
  52. birder/net/pvt_v1.py +5 -5
  53. birder/net/pvt_v2.py +5 -5
  54. birder/net/repghost.py +1 -30
  55. birder/net/resmlp.py +2 -2
  56. birder/net/resnest.py +3 -0
  57. birder/net/resnet_v1.py +125 -1
  58. birder/net/resnet_v2.py +75 -1
  59. birder/net/resnext.py +35 -1
  60. birder/net/rope_deit3.py +33 -136
  61. birder/net/rope_flexivit.py +18 -18
  62. birder/net/rope_vit.py +3 -735
  63. birder/net/simple_vit.py +22 -16
  64. birder/net/smt.py +1 -1
  65. birder/net/squeezenet.py +5 -12
  66. birder/net/squeezenext.py +0 -24
  67. birder/net/ssl/capi.py +1 -1
  68. birder/net/ssl/data2vec.py +1 -1
  69. birder/net/ssl/dino_v2.py +2 -2
  70. birder/net/ssl/franca.py +2 -2
  71. birder/net/ssl/i_jepa.py +1 -1
  72. birder/net/ssl/ibot.py +1 -1
  73. birder/net/swiftformer.py +12 -2
  74. birder/net/swin_transformer_v2.py +1 -1
  75. birder/net/tiny_vit.py +3 -16
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +35 -963
  78. birder/net/vit_sam.py +13 -38
  79. birder/net/xcit.py +7 -6
  80. birder/tools/introspection.py +1 -1
  81. birder/tools/model_info.py +3 -1
  82. birder/version.py +1 -1
  83. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/METADATA +1 -1
  84. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/RECORD +88 -90
  85. birder/net/mobilenet_v3_small.py +0 -43
  86. birder/net/se_resnet_v1.py +0 -105
  87. birder/net/se_resnet_v2.py +0 -59
  88. birder/net/se_resnext.py +0 -30
  89. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/WHEEL +0 -0
  90. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/entry_points.txt +0 -0
  91. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/licenses/LICENSE +0 -0
  92. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/top_level.txt +0 -0
birder/net/deit.py CHANGED
@@ -16,6 +16,9 @@ import torch
 from torch import nn
 
 from birder.model_registry import registry
+from birder.net._vit_configs import BASE
+from birder.net._vit_configs import SMALL
+from birder.net._vit_configs import TINY
 from birder.net.base import BaseNet
 from birder.net.vit import Encoder
 from birder.net.vit import PatchEmbed
@@ -94,6 +97,10 @@ class DeiT(BaseNet):
         self.classifier = self.create_classifier()
         self.distillation_output = False
 
+        self.max_stride = patch_size
+        self.stem_stride = patch_size
+        self.stem_width = hidden_dim
+
         # Weight initialization
         if isinstance(self.conv_proj, nn.Conv2d):
             # Init the patchify stem
@@ -200,38 +207,17 @@ class DeiT(BaseNet):
 registry.register_model_config(
     "deit_t16",
     DeiT,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 3,
-        "hidden_dim": 192,
-        "mlp_dim": 768,
-        "drop_path_rate": 0.0,
-    },
+    config={"patch_size": 16, **TINY},
 )
 registry.register_model_config(
     "deit_s16",
     DeiT,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 6,
-        "hidden_dim": 384,
-        "mlp_dim": 1536,
-        "drop_path_rate": 0.1,
-    },
+    config={"patch_size": 16, **SMALL, "drop_path_rate": 0.1},  # Override the SMALL definition
 )
 registry.register_model_config(
     "deit_b16",
     DeiT,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 12,
-        "hidden_dim": 768,
-        "mlp_dim": 3072,
-        "drop_path_rate": 0.1,
-    },
+    config={"patch_size": 16, **BASE},
 )
 
 registry.register_weights(
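Note: TINY, SMALL and BASE come from the new birder/net/_vit_configs.py module (file 7 in the list above), whose contents are not part of this diff. Judging from the literals deleted in this hunk, the presets presumably look roughly like the sketch below; TINY's drop_path_rate of 0.0 and BASE's 0.1 follow from the un-overridden deit_t16 and deit_b16 configs, while SMALL's default is an assumption.

# Hypothetical sketch of the size presets in birder/net/_vit_configs.py (inferred, not shown in this diff)
TINY = {"num_layers": 12, "num_heads": 3, "hidden_dim": 192, "mlp_dim": 768, "drop_path_rate": 0.0}
SMALL = {"num_layers": 12, "num_heads": 6, "hidden_dim": 384, "mlp_dim": 1536, "drop_path_rate": 0.0}  # default assumed
BASE = {"num_layers": 12, "num_heads": 12, "hidden_dim": 768, "mlp_dim": 3072, "drop_path_rate": 0.1}

# Later keys win in a dict literal, so an explicit key overrides the unpacked preset:
config = {"patch_size": 16, **SMALL, "drop_path_rate": 0.1}
assert config["drop_path_rate"] == 0.1 and config["hidden_dim"] == 384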
@@ -242,7 +228,7 @@ registry.register_weights(
         "formats": {
             "pt": {
                 "file_size": 21.7,
-                "sha256": "ac124122dec9f1bceff383a6a555ca375ca1b613caf486dac3f29d87afac03b3",
+                "sha256": "68b33aba0c1be5e78d4a33e74a7c1ea72b6abb232d59f0048ff9b8342e43246e",
             }
         },
         "net": {"network": "deit_t16", "tag": "il-common"},
@@ -258,7 +244,7 @@ registry.register_weights(
         "formats": {
             "pt": {
                 "file_size": 21.7,
-                "sha256": "fafd0c3c65f9c35318f449f60485f640917736ee7b44056be55c2226909ffdb8",
+                "sha256": "f693e89fc350341141c55152bec9f499df63738e8423071f3b8e71801c3e5415",
             }
         },
         "net": {"network": "deit_t16", "tag": "dist-il-common"},
birder/net/deit3.py CHANGED
@@ -15,6 +15,12 @@ from torch import nn
 
 from birder.common.masking import mask_tensor
 from birder.model_registry import registry
+from birder.net._vit_configs import BASE
+from birder.net._vit_configs import HUGE
+from birder.net._vit_configs import LARGE
+from birder.net._vit_configs import MEDIUM
+from birder.net._vit_configs import SMALL
+from birder.net._vit_configs import TINY
 from birder.net.base import DetectorBackbone
 from birder.net.base import MaskedTokenOmissionMixin
 from birder.net.base import MaskedTokenRetentionMixin
@@ -368,279 +374,126 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
 registry.register_model_config(
     "deit3_t16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 3,
-        "hidden_dim": 192,
-        "mlp_dim": 768,
-        "drop_path_rate": 0.0,
-    },
+    config={"patch_size": 16, **TINY},
+)
+registry.register_model_config(
+    "deit3_t14",
+    DeiT3,
+    config={"patch_size": 14, **TINY},
 )
 registry.register_model_config(
     "deit3_s16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 6,
-        "hidden_dim": 384,
-        "mlp_dim": 1536,
-        "drop_path_rate": 0.05,
-    },
+    config={"patch_size": 16, **SMALL, "drop_path_rate": 0.05},
 )
 registry.register_model_config(
     "deit3_s14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 12,
-        "num_heads": 6,
-        "hidden_dim": 384,
-        "mlp_dim": 1536,
-        "drop_path_rate": 0.05,
-    },
+    config={"patch_size": 14, **SMALL, "drop_path_rate": 0.05},
 )
 registry.register_model_config(
     "deit3_m16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 8,
-        "hidden_dim": 512,
-        "mlp_dim": 2048,
-        "drop_path_rate": 0.1,
-    },
+    config={"patch_size": 16, **MEDIUM, "drop_path_rate": 0.1},
 )
 registry.register_model_config(
     "deit3_m14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 12,
-        "num_heads": 8,
-        "hidden_dim": 512,
-        "mlp_dim": 2048,
-        "drop_path_rate": 0.1,
-    },
+    config={"patch_size": 14, **MEDIUM, "drop_path_rate": 0.1},
 )
 registry.register_model_config(
     "deit3_b16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 12,
-        "hidden_dim": 768,
-        "mlp_dim": 3072,
-        "drop_path_rate": 0.2,
-    },
+    config={"patch_size": 16, **BASE, "drop_path_rate": 0.2},
 )
 registry.register_model_config(
     "deit3_b14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 12,
-        "num_heads": 12,
-        "hidden_dim": 768,
-        "mlp_dim": 3072,
-        "drop_path_rate": 0.2,
-    },
+    config={"patch_size": 14, **BASE, "drop_path_rate": 0.2},
 )
 registry.register_model_config(
     "deit3_l16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 24,
-        "num_heads": 16,
-        "hidden_dim": 1024,
-        "mlp_dim": 4096,
-        "drop_path_rate": 0.45,
-    },
+    config={"patch_size": 16, **LARGE, "drop_path_rate": 0.45},
 )
 registry.register_model_config(
     "deit3_l14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 24,
-        "num_heads": 16,
-        "hidden_dim": 1024,
-        "mlp_dim": 4096,
-        "drop_path_rate": 0.45,
-    },
+    config={"patch_size": 14, **LARGE, "drop_path_rate": 0.45},
 )
 registry.register_model_config(
     "deit3_h16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 32,
-        "num_heads": 16,
-        "hidden_dim": 1280,
-        "mlp_dim": 5120,
-        "drop_path_rate": 0.55,
-    },
+    config={"patch_size": 16, **HUGE, "drop_path_rate": 0.55},
 )
 registry.register_model_config(
     "deit3_h14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 32,
-        "num_heads": 16,
-        "hidden_dim": 1280,
-        "mlp_dim": 5120,
-        "drop_path_rate": 0.55,
-    },
+    config={"patch_size": 14, **HUGE, "drop_path_rate": 0.55},
 )
 
 # With registers
+####################
+
 registry.register_model_config(
     "deit3_reg4_t16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 3,
-        "hidden_dim": 192,
-        "mlp_dim": 768,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.0,
-    },
+    config={"patch_size": 16, **TINY, "num_reg_tokens": 4},
+)
+registry.register_model_config(
+    "deit3_reg4_t14",
+    DeiT3,
+    config={"patch_size": 14, **TINY, "num_reg_tokens": 4},
 )
 registry.register_model_config(
     "deit3_reg4_s16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 6,
-        "hidden_dim": 384,
-        "mlp_dim": 1536,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.05,
-    },
+    config={"patch_size": 16, **SMALL, "num_reg_tokens": 4, "drop_path_rate": 0.05},
 )
 registry.register_model_config(
     "deit3_reg4_s14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 12,
-        "num_heads": 6,
-        "hidden_dim": 384,
-        "mlp_dim": 1536,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.05,
-    },
+    config={"patch_size": 14, **SMALL, "num_reg_tokens": 4, "drop_path_rate": 0.05},
 )
 registry.register_model_config(
     "deit3_reg4_m16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 8,
-        "hidden_dim": 512,
-        "mlp_dim": 2048,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.1,
-    },
+    config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4, "drop_path_rate": 0.1},
 )
 registry.register_model_config(
     "deit3_reg4_m14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 12,
-        "num_heads": 8,
-        "hidden_dim": 512,
-        "mlp_dim": 2048,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.1,
-    },
+    config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4, "drop_path_rate": 0.1},
 )
 registry.register_model_config(
     "deit3_reg4_b16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 12,
-        "num_heads": 12,
-        "hidden_dim": 768,
-        "mlp_dim": 3072,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.2,
-    },
+    config={"patch_size": 16, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.2},
 )
 registry.register_model_config(
     "deit3_reg4_b14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 12,
-        "num_heads": 12,
-        "hidden_dim": 768,
-        "mlp_dim": 3072,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.2,
-    },
+    config={"patch_size": 14, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.2},
 )
 registry.register_model_config(
     "deit3_reg4_l16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 24,
-        "num_heads": 16,
-        "hidden_dim": 1024,
-        "mlp_dim": 4096,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.45,
-    },
+    config={"patch_size": 16, **LARGE, "num_reg_tokens": 4, "drop_path_rate": 0.45},
 )
 registry.register_model_config(
     "deit3_reg4_l14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 24,
-        "num_heads": 16,
-        "hidden_dim": 1024,
-        "mlp_dim": 4096,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.45,
-    },
+    config={"patch_size": 14, **LARGE, "num_reg_tokens": 4, "drop_path_rate": 0.45},
 )
 registry.register_model_config(
     "deit3_reg4_h16",
     DeiT3,
-    config={
-        "patch_size": 16,
-        "num_layers": 32,
-        "num_heads": 16,
-        "hidden_dim": 1280,
-        "mlp_dim": 5120,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.55,
-    },
+    config={"patch_size": 16, **HUGE, "num_reg_tokens": 4, "drop_path_rate": 0.55},
 )
 registry.register_model_config(
     "deit3_reg4_h14",
     DeiT3,
-    config={
-        "patch_size": 14,
-        "num_layers": 32,
-        "num_heads": 16,
-        "hidden_dim": 1280,
-        "mlp_dim": 5120,
-        "num_reg_tokens": 4,
-        "drop_path_rate": 0.55,
-    },
+    config={"patch_size": 14, **HUGE, "num_reg_tokens": 4, "drop_path_rate": 0.55},
 )
 
 registry.register_weights(
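DeiT3 additionally imports MEDIUM, LARGE and HUGE from the same module. Every deit3_* config overrides drop_path_rate explicitly, so only the architectural keys can be read off the deleted literals; a hypothetical sketch of the remaining presets:

# Hypothetical sketch, inferred from the literals removed above (drop_path_rate defaults assumed)
MEDIUM = {"num_layers": 12, "num_heads": 8, "hidden_dim": 512, "mlp_dim": 2048, "drop_path_rate": 0.0}
LARGE = {"num_layers": 24, "num_heads": 16, "hidden_dim": 1024, "mlp_dim": 4096, "drop_path_rate": 0.0}
HUGE = {"num_layers": 32, "num_heads": 16, "hidden_dim": 1280, "mlp_dim": 5120, "drop_path_rate": 0.0}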
@@ -651,7 +504,7 @@ registry.register_weights(
         "formats": {
             "pt": {
                 "file_size": 21.5,
-                "sha256": "6cd9749a9522f8ff61088e38702553fb1c4d2547b417c499652e3bfa6a81e77a",
+                "sha256": "a04141c7f6c459ae075a48ccdee5b82d191bbaa82337673140c06ef82f0a8dc5",
             }
         },
         "net": {"network": "deit3_t16", "tag": "il-common"},
@@ -665,7 +518,7 @@ registry.register_weights(
         "formats": {
             "pt": {
                 "file_size": 21.5,
-                "sha256": "6806a5ae7d45f1c84b25e9869a9cbc7de94368fe9573dc3777acf2da8c83dc4e",
+                "sha256": "d26320462da64df6d62b307f7fb35d09c86a5f073002dfb24a51f014074e65c3",
             }
         },
         "net": {"network": "deit3_reg4_t16", "tag": "il-common"},
birder/net/densenet.py CHANGED
@@ -104,19 +104,20 @@ class DenseNet(DetectorBackbone):
         num_features = num_init_features
         stages: OrderedDict[str, nn.Module] = OrderedDict()
         return_channels: list[int] = []
-        layers = []
         for i, num_layers in enumerate(layer_list):
+            stage_layers = []
+            if i != 0:
+                stage_layers.append(TransitionBlock(num_features, num_features // 2))
+                num_features = num_features // 2
 
-            layers.append(DenseBlock(num_features, num_layers=num_layers, growth_rate=growth_rate))
+            stage_layers.append(DenseBlock(num_features, num_layers=num_layers, growth_rate=growth_rate))
             num_features = num_features + (num_layers * growth_rate)
+            if i == len(layer_list) - 1:
+                stage_layers.append(nn.BatchNorm2d(num_features))
+                stage_layers.append(nn.ReLU(inplace=True))
 
-            stages[f"stage{i+1}"] = nn.Sequential(*layers)
+            stages[f"stage{i+1}"] = nn.Sequential(*stage_layers)
             return_channels.append(num_features)
-            layers = []
-
-            if i != len(layer_list) - 1:
-                layers.append(TransitionBlock(num_features, num_features // 2))
-                num_features = num_features // 2
 
         self.body = nn.Sequential(stages)
         self.features = nn.Sequential(
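The restructured loop folds each TransitionBlock into the start of the stage that follows it and appends the final BatchNorm+ReLU to the last stage, so every stages[f"stage{i+1}"] entry is now self-contained; the channel arithmetic itself is unchanged. A minimal sketch of that bookkeeping, using the canonical DenseNet-121 settings (num_init_features=64, growth_rate=32, layer_list=(6, 12, 24, 16) -- standard values, not shown in this hunk):

# Channel bookkeeping of the rewritten loop (illustration only, not package code)
layer_list = [6, 12, 24, 16]
growth_rate = 32
num_features = 64  # num_init_features
return_channels = []
for i, num_layers in enumerate(layer_list):
    if i != 0:
        num_features = num_features // 2  # transition block now opens the stage
    num_features = num_features + (num_layers * growth_rate)  # dense block
    return_channels.append(num_features)
print(return_channels)  # [256, 512, 1024, 1024]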
birder/net/detection/deformable_detr.py CHANGED
@@ -9,7 +9,7 @@ https://arxiv.org/abs/2010.04159
 
 Changes from original:
 * Removed two stage support
-* Zero cost matrix elements on overflow (HungarianMatcher)
+* Penalize cost matrix elements on overflow (HungarianMatcher)
 """
 
 # Reference license: Apache-2.0 (both)
@@ -89,7 +89,10 @@ class HungarianMatcher(nn.Module):
         # Final cost matrix
         C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
         C = C.view(B, num_queries, -1).cpu()
-        C[C.isnan() | C.isinf()] = 0.0
+        finite = torch.isfinite(C)
+        if not torch.all(finite):
+            penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
+            C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)
 
         sizes = [len(v["boxes"]) for v in targets]
         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
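The old code zeroed non-finite cost entries, which turns an overflowed pair into the cheapest possible match; the new code replaces them with a value larger than any finite cost, so linear_sum_assignment steers away from them. A toy illustration of the difference (not package code; the 2x2 matrix is made up):

import torch
from scipy.optimize import linear_sum_assignment

C = torch.tensor([[0.2, float("inf")], [0.1, 0.3]])  # one overflowed entry

finite = torch.isfinite(C)
if not torch.all(finite):
    penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
    C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)

rows, cols = linear_sum_assignment(C.numpy())
print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 0), (1, 1)] -- the overflowed pair is avoided
# Zeroing instead (the old behaviour) yields [(0, 1), (1, 0)], i.e. the bad pair gets selected.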
birder/net/detection/detr.py CHANGED
@@ -6,7 +6,7 @@ Paper "End-to-End Object Detection with Transformers", https://arxiv.org/abs/200
 
 Changes from original:
 * Move background index to first from last (to be inline with the rest of Birder detectors)
-* Zero cost matrix elements on overflow (HungarianMatcher)
+* Penalize cost matrix elements on overflow (HungarianMatcher)
 """
 
 # Reference license: Apache-2.0
@@ -78,7 +78,10 @@ class HungarianMatcher(nn.Module):
         # Final cost matrix
         C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
         C = C.view(B, num_queries, -1).cpu()
-        C[C.isnan() | C.isinf()] = 0.0
+        finite = torch.isfinite(C)
+        if not torch.all(finite):
+            penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
+            C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)
 
         sizes = [len(v["boxes"]) for v in targets]
         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
birder/net/detection/efficientdet.py CHANGED
@@ -195,7 +195,7 @@ class FpnCombine(nn.Module):
         )
 
         if weight_method in {"attn", "fastattn"}:
-            self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)), requires_grad=True)  # WSM
+            self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)))  # WSM
         else:
             self.edge_weights = None
 
birder/net/dpn.py CHANGED
@@ -181,9 +181,8 @@ class DPN(BaseNet):
             layers.append(DualPathBlock(num_features, r, r, bw, inc, groups, "normal"))
             num_features += inc
 
-        self.norm_act = nn.Sequential(nn.BatchNorm2d(num_features), nn.ELU())
-
         self.body = nn.Sequential(*layers)
+        self.norm_act = nn.Sequential(nn.BatchNorm2d(num_features), nn.ReLU())
         self.features = nn.Sequential(
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
             nn.Flatten(1),
birder/net/edgenext.py CHANGED
@@ -10,6 +10,7 @@ https://arxiv.org/abs/2206.10589
 
 import math
 from collections import OrderedDict
+from functools import partial
 from typing import Any
 from typing import Optional
 
@@ -277,7 +278,7 @@ class EdgeNeXt(DetectorBackbone):
             stride=(4, 4),
             padding=(0, 0),
             bias=True,
-            norm_layer=LayerNorm2d,
+            norm_layer=partial(LayerNorm2d, eps=1e-6),
             activation_layer=None,
         )
 
birder/net/edgevit.py CHANGED
@@ -4,6 +4,9 @@ https://github.com/saic-fi/edgevit/blob/master/src/edgevit.py
 
 Paper "EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers",
 https://arxiv.org/abs/2205.03436
+
+Changes from original:
+* Removed classifier bias
 """
 
 # Reference license: Apache-2.0
birder/net/efficientformer_v1.py CHANGED
@@ -9,6 +9,7 @@ https://arxiv.org/abs/2206.01191
 
 Changes from original:
 * Removed attention bias cache
+* Stem bias term removed
 """
 
 # Reference license: Apache-2.0 (both)
@@ -76,7 +77,7 @@ class Downsample(nn.Module):
         stride: tuple[int, int],
     ) -> None:
         super().__init__()
-        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
+        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
         self.norm = nn.BatchNorm2d(out_channels)
 
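The two padding formulas agree for odd kernels (e.g. 3 -> 1) and differ only for even ones, where the old kernel_size // 2 over-pads and the downsampled map comes out one pixel larger than a clean halving. Whether Downsample is ever built with an even kernel is not visible in this hunk; the arithmetic, for the record:

def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Standard Conv2d output-size formula
    return (size + 2 * padding - kernel) // stride + 1

k = 2  # hypothetical even kernel, stride 2, 14x14 input
print(conv_out(14, k, 2, padding=k // 2))        # 8 (old formula over-pads)
print(conv_out(14, k, 2, padding=(k - 1) // 2))  # 7 (new formula: clean halving)
print(conv_out(14, 3, 2, 3 // 2), conv_out(14, 3, 2, (3 - 1) // 2))  # 7 7 (odd kernels unaffected)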
birder/net/efficientformer_v2.py CHANGED
@@ -9,6 +9,7 @@ https://arxiv.org/abs/2212.08059
 
 Changes from original:
 * Removed attention bias cache
+* Removed biases before norms
 """
 
 # Reference license: Apache-2.0 (both)
@@ -244,9 +245,24 @@ class ConvMLP(nn.Module):
         drop: float,
     ) -> None:
         super().__init__()
-        self.fc1 = Conv2dNormActivation(in_features, hidden_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
+        self.fc1 = Conv2dNormActivation(
+            in_features,
+            hidden_features,
+            kernel_size=(1, 1),
+            stride=(1, 1),
+            padding=(0, 0),
+            activation_layer=nn.GELU,
+            inplace=None,
+        )
         self.mid = Conv2dNormActivation(
-            hidden_features, hidden_features, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=hidden_features
+            hidden_features,
+            hidden_features,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            groups=hidden_features,
+            activation_layer=nn.GELU,
+            inplace=None,
         )
         self.drop1 = nn.Dropout(drop)
         self.fc2 = Conv2dNormActivation(
@@ -676,32 +692,3 @@ registry.register_model_config(
         ],
     },
 )
-
-registry.register_weights(
-    "efficientformer_v2_s0_il-common",
-    {
-        "description": "EfficientFormer v2 S0 model trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 13.2,
-                "sha256": "b5ba923d351d45a04686b5bda037438719e0f442a41a34207a7f19737a8edb45",
-            }
-        },
-        "net": {"network": "efficientformer_v2_s0", "tag": "il-common"},
-    },
-)
-registry.register_weights(
-    "efficientformer_v2_s1_il-common",
-    {
-        "description": "EfficientFormer v2 S1 model trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 22.9,
-                "sha256": "6b7ce6bbf5aa83e222cd16d8f07e749cdbb703fd383f99e88362ec8401d81401",
-            }
-        },
-        "net": {"network": "efficientformer_v2_s1", "tag": "il-common"},
-    },
-)
birder/net/efficientnet_v2.py CHANGED
@@ -4,6 +4,9 @@ https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py
 
 Paper "EfficientNetV2: Smaller Models and Faster Training",
 https://arxiv.org/abs/2104.00298
+
+Changes from original:
+* Using nn.BatchNorm2d with eps 1e-5 instead of 1e-3
 """
 
 # Reference license: BSD 3-Clause