birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100):
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/net/nextvit.py CHANGED
@@ -355,18 +355,18 @@ class NextViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
355
355
  # Weights initialization
356
356
  for m in self.modules():
357
357
  if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.BatchNorm1d)):
358
- nn.init.constant_(m.weight, 1.0)
359
- nn.init.constant_(m.bias, 0)
358
+ nn.init.ones_(m.weight)
359
+ nn.init.zeros_(m.bias)
360
360
 
361
361
  elif isinstance(m, nn.Linear):
362
362
  nn.init.normal_(m.weight, std=0.02)
363
363
  if hasattr(m, "bias") and m.bias is not None:
364
- nn.init.constant_(m.bias, 0)
364
+ nn.init.zeros_(m.bias)
365
365
 
366
366
  elif isinstance(m, nn.Conv2d):
367
367
  nn.init.normal_(m.weight, std=0.02)
368
368
  if hasattr(m, "bias") and m.bias is not None:
369
- nn.init.constant_(m.bias, 0)
369
+ nn.init.zeros_(m.bias)
370
370
 
371
371
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
372
372
  x = self.stem(x)
birder/net/resnext.py CHANGED
@@ -205,8 +205,8 @@ class ResNeXt(DetectorBackbone):
205
205
  nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
206
206
 
207
207
  elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
208
- nn.init.constant_(m.weight, 1)
209
- nn.init.constant_(m.bias, 0)
208
+ nn.init.ones_(m.weight)
209
+ nn.init.zeros_(m.bias)
210
210
 
211
211
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
212
212
  x = self.stem(x)
birder/net/rope_deit3.py CHANGED
@@ -223,7 +223,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
223
223
  pt_grid_size=self.pt_grid_size,
224
224
  ),
225
225
  dim=-1,
226
- ).to(self.rope.pos_embed.device)
226
+ ).to(self.rope.pos_embed.device, dtype=self.rope.pos_embed.dtype)
227
227
 
228
228
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
229
229
  H, W = x.shape[-2:]
@@ -249,7 +249,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
249
249
  xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
250
250
 
251
251
  out: dict[str, torch.Tensor] = {}
252
- for stage_name, stage_x in zip(self.return_stages, xs):
252
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
253
253
  stage_x = stage_x[:, self.num_special_tokens :]
254
254
  stage_x = stage_x.permute(0, 2, 1)
255
255
  B, C, _ = stage_x.size()
@@ -292,7 +292,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
292
292
  pt_grid_size=self.pt_grid_size,
293
293
  ),
294
294
  dim=-1,
295
- ).to(self.rope.pos_embed.device)
295
+ ).to(self.rope.pos_embed.device, dtype=self.rope.pos_embed.dtype)
296
296
 
297
297
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
298
298
  for param in self.parameters():
@@ -342,7 +342,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
342
342
  xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
343
343
 
344
344
  out: dict[str, torch.Tensor] = {}
345
- for stage_name, stage_x in zip(self.return_stages, xs):
345
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
346
346
  stage_x = stage_x[:, self.num_special_tokens :]
347
347
  stage_x = stage_x.permute(0, 2, 1)
348
348
  B, C, _ = stage_x.size()
birder/net/rope_vit.py CHANGED
@@ -648,7 +648,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
648
648
  pt_grid_size=self.pt_grid_size,
649
649
  ),
650
650
  dim=-1,
651
- ).to(self.rope.pos_embed.device)
651
+ ).to(self.rope.pos_embed.device, dtype=self.rope.pos_embed.dtype)
652
652
 
653
653
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
654
654
  for param in self.parameters():
@@ -698,7 +698,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
698
698
  xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
699
699
 
700
700
  out: dict[str, torch.Tensor] = {}
701
- for stage_name, stage_x in zip(self.return_stages, xs):
701
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
702
702
  stage_x = stage_x[:, self.num_special_tokens :]
703
703
  stage_x = stage_x.permute(0, 2, 1)
704
704
  B, C, _ = stage_x.size()
birder/net/simple_vit.py CHANGED
@@ -215,7 +215,7 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
215
215
  xs = self.encoder.forward_features(x, out_indices=self.out_indices)
216
216
 
217
217
  out: dict[str, torch.Tensor] = {}
218
- for stage_name, stage_x in zip(self.return_stages, xs):
218
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
219
219
  stage_x = stage_x[:, self.num_special_tokens :]
220
220
  stage_x = stage_x.permute(0, 2, 1)
221
221
  B, C, _ = stage_x.size()
birder/net/squeezenet.py CHANGED
@@ -76,7 +76,7 @@ class SqueezeNet(BaseNet):
76
76
  if isinstance(m, nn.Conv2d):
77
77
  nn.init.kaiming_uniform_(m.weight)
78
78
  if m.bias is not None:
79
- nn.init.constant_(m.bias, 0)
79
+ nn.init.zeros_(m.bias)
80
80
 
81
81
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
82
82
  x = self.stem(x)
birder/net/ssl/capi.py CHANGED
@@ -66,14 +66,20 @@ def sinkhorn_knopp_(M: torch.Tensor, temp: float, n_iterations: int, eps: float
66
66
 
67
67
 
68
68
  class SinkhornQueue(nn.Module):
69
- def __init__(self, queue_size: int, position_wise: bool) -> None:
69
+ def __init__(self, queue_size: int, position_wise: bool, dim: int, seq_len: Optional[int] = None) -> None:
70
70
  super().__init__()
71
71
  self.queue_size = queue_size
72
72
  self.position_wise = position_wise
73
73
  self.active = True
74
- self.queue = nn.Buffer(torch.empty(0), persistent=False)
75
- self.queue_ptr: int = 0
76
- self.queue_full: bool = False
74
+ if self.position_wise is True:
75
+ assert seq_len is not None, "seq_len is required when position_wise is True"
76
+
77
+ self.queue = nn.Buffer(torch.empty(seq_len, queue_size, dim))
78
+ else:
79
+ self.queue = nn.Buffer(torch.empty(queue_size, dim))
80
+
81
+ self.queue_ptr = nn.Buffer(torch.zeros(1, dtype=torch.long))
82
+ self.queue_full = nn.Buffer(torch.zeros(1, dtype=torch.bool))
77
83
 
78
84
  def set_active(self, active: bool) -> None:
79
85
  self.active = active
@@ -81,13 +87,13 @@ class SinkhornQueue(nn.Module):
81
87
  def get(self) -> Optional[torch.Tensor]:
82
88
  if self.active is False:
83
89
  return None
84
- if self.queue_full is False:
90
+ if self.queue_full.item() is False:
85
91
  return None
86
92
 
87
93
  return self.queue
88
94
 
89
95
  @torch.no_grad() # type: ignore[untyped-decorator]
90
- def forward(self, values: torch.Tensor) -> None: # pylint: disable=too-many-branches
96
+ def forward(self, values: torch.Tensor) -> None:
91
97
  if self.active is False:
92
98
  return
93
99
  if values.numel() == 0:
@@ -98,21 +104,16 @@ class SinkhornQueue(nn.Module):
98
104
  if values.dim() != 3:
99
105
  raise ValueError("SinkhornQueue in position wise mode expects a 3D tensor")
100
106
 
101
- seq_len = values.size(0)
102
107
  batch_size = values.size(1)
103
- dim = values.size(2)
104
-
105
- if self.queue.numel() == 0:
106
- self.queue = values.new_empty(seq_len, self.queue_size, dim)
107
108
 
108
109
  values = values.detach()
109
110
  if batch_size >= self.queue_size:
110
111
  self.queue.copy_(values[:, -self.queue_size :, :])
111
- self.queue_ptr = 0
112
- self.queue_full = True
112
+ self.queue_ptr.zero_()
113
+ self.queue_full.fill_(True)
113
114
  return
114
115
 
115
- ptr = self.queue_ptr
116
+ ptr = self.queue_ptr.item()
116
117
  end = ptr + batch_size
117
118
  if end <= self.queue_size:
118
119
  self.queue[:, ptr:end, :].copy_(values)
@@ -121,26 +122,23 @@ class SinkhornQueue(nn.Module):
121
122
  self.queue[:, ptr:, :].copy_(values[:, :first, :])
122
123
  self.queue[:, : end - self.queue_size, :].copy_(values[:, first:, :])
123
124
 
124
- self.queue_ptr = end % self.queue_size
125
+ self.queue_ptr.fill_(end % self.queue_size)
125
126
  if end >= self.queue_size:
126
- self.queue_full = True
127
+ self.queue_full.fill_(True)
127
128
 
128
129
  else:
129
130
  # values shape: (N, dim) - 2D
130
131
  if values.dim() != 2:
131
132
  raise ValueError("SinkhornQueue in non-position wise mode expects a 2D tensor")
132
133
 
133
- if self.queue.numel() == 0:
134
- self.queue = values.new_empty(self.queue_size, values.size(1))
135
-
136
134
  values = values.detach()
137
135
  if values.size(0) >= self.queue_size:
138
136
  self.queue.copy_(values[-self.queue_size :])
139
- self.queue_ptr = 0
140
- self.queue_full = True
137
+ self.queue_ptr.zero_()
138
+ self.queue_full.fill_(True)
141
139
  return
142
140
 
143
- ptr = self.queue_ptr
141
+ ptr = self.queue_ptr.item()
144
142
  end = ptr + values.size(0)
145
143
  if end <= self.queue_size:
146
144
  self.queue[ptr:end].copy_(values)
@@ -149,9 +147,9 @@ class SinkhornQueue(nn.Module):
149
147
  self.queue[ptr:].copy_(values[:first])
150
148
  self.queue[: end - self.queue_size].copy_(values[first:])
151
149
 
152
- self.queue_ptr = end % self.queue_size
150
+ self.queue_ptr.fill_(end % self.queue_size)
153
151
  if end >= self.queue_size:
154
- self.queue_full = True
152
+ self.queue_full.fill_(True)
155
153
 
156
154
 
157
155
  class OnlineClustering(nn.Module):
@@ -166,6 +164,7 @@ class OnlineClustering(nn.Module):
166
164
  pred_temp: float,
167
165
  position_wise_sk: bool = True,
168
166
  queue_size: Optional[int] = None,
167
+ seq_len: Optional[int] = None,
169
168
  ):
170
169
  super().__init__()
171
170
  self.n_sk_iter = n_sk_iter
@@ -176,7 +175,9 @@ class OnlineClustering(nn.Module):
176
175
  if queue_size is None:
177
176
  self.sinkhorn_queue = None
178
177
  else:
179
- self.sinkhorn_queue = SinkhornQueue(queue_size, position_wise=position_wise_sk)
178
+ self.sinkhorn_queue = SinkhornQueue(
179
+ queue_size, position_wise=position_wise_sk, dim=out_dim, seq_len=seq_len
180
+ )
180
181
 
181
182
  # Weight initialization
182
183
  nn.init.normal_(self.layer.weight, std=1.0)
@@ -399,6 +400,11 @@ class CAPITeacher(SSLBaseNet):
399
400
  sk_mode: str = self.config["sk_mode"]
400
401
  queue_size: Optional[int] = self.config.get("queue_size", None)
401
402
 
403
+ queue_seq_len: Optional[int] = None
404
+ if sk_mode == "position-wise" and queue_size is not None:
405
+ input_size = (self.size[0] // self.backbone.max_stride, self.size[1] // self.backbone.max_stride)
406
+ queue_seq_len = input_size[0] * input_size[1]
407
+
402
408
  self.head = OnlineClustering(
403
409
  self.backbone.embedding_size,
404
410
  num_clusters,
@@ -408,6 +414,7 @@ class CAPITeacher(SSLBaseNet):
408
414
  pred_temp=pred_temp,
409
415
  position_wise_sk=sk_mode == "position-wise",
410
416
  queue_size=queue_size,
417
+ seq_len=queue_seq_len,
411
418
  )
412
419
 
413
420
  def forward( # type: ignore[override] # pylint: disable=arguments-differ
birder/net/ssl/dino_v2.py CHANGED
@@ -76,13 +76,13 @@ class DINOHead(nn.Module):
76
76
 
77
77
 
78
78
  class SinkhornQueue(nn.Module):
79
- def __init__(self, queue_size: int) -> None:
79
+ def __init__(self, queue_size: int, dim: int) -> None:
80
80
  super().__init__()
81
81
  self.queue_size = queue_size
82
82
  self.active = True
83
- self.queue = nn.Buffer(torch.empty(0), persistent=False)
84
- self.queue_ptr: int = 0
85
- self.queue_full: bool = False
83
+ self.queue = nn.Buffer(torch.empty(queue_size, dim))
84
+ self.queue_ptr = nn.Buffer(torch.zeros(1, dtype=torch.long))
85
+ self.queue_full = nn.Buffer(torch.zeros(1, dtype=torch.bool))
86
86
 
87
87
  def set_active(self, active: bool) -> None:
88
88
  self.active = active
@@ -90,7 +90,7 @@ class SinkhornQueue(nn.Module):
90
90
  def get(self) -> Optional[torch.Tensor]:
91
91
  if self.active is False:
92
92
  return None
93
- if self.queue_full is False:
93
+ if self.queue_full.item() is False:
94
94
  return None
95
95
 
96
96
  return self.queue
@@ -104,17 +104,14 @@ class SinkhornQueue(nn.Module):
104
104
  if values.dim() != 2:
105
105
  raise ValueError("SinkhornQueue expects a 2D tensor")
106
106
 
107
- if self.queue.numel() == 0:
108
- self.queue = values.new_empty(self.queue_size, values.size(1))
109
-
110
107
  values = values.detach()
111
108
  if values.size(0) >= self.queue_size:
112
109
  self.queue.copy_(values[-self.queue_size :])
113
- self.queue_ptr = 0
114
- self.queue_full = True
110
+ self.queue_ptr.zero_()
111
+ self.queue_full.fill_(True)
115
112
  return
116
113
 
117
- ptr = self.queue_ptr
114
+ ptr = self.queue_ptr.item()
118
115
  end = ptr + values.size(0)
119
116
  if end <= self.queue_size:
120
117
  self.queue[ptr:end].copy_(values)
@@ -123,9 +120,9 @@ class SinkhornQueue(nn.Module):
123
120
  self.queue[ptr:].copy_(values[:first])
124
121
  self.queue[: end - self.queue_size].copy_(values[first:])
125
122
 
126
- self.queue_ptr = end % self.queue_size
123
+ self.queue_ptr.fill_(end % self.queue_size)
127
124
  if end >= self.queue_size:
128
- self.queue_full = True
125
+ self.queue_full.fill_(True)
129
126
 
130
127
 
131
128
  class DINOLoss(nn.Module):
@@ -139,7 +136,7 @@ class DINOLoss(nn.Module):
139
136
  if queue_size is None:
140
137
  self.sinkhorn_queue = None
141
138
  else:
142
- self.sinkhorn_queue = SinkhornQueue(queue_size)
139
+ self.sinkhorn_queue = SinkhornQueue(queue_size, dim=out_dim)
143
140
 
144
141
  self.updated = True
145
142
  self.reduce_handle: Any = None
@@ -267,7 +264,7 @@ class iBOTPatchLoss(nn.Module):
267
264
  if queue_size is None:
268
265
  self.sinkhorn_queue = None
269
266
  else:
270
- self.sinkhorn_queue = SinkhornQueue(queue_size)
267
+ self.sinkhorn_queue = SinkhornQueue(queue_size, dim=patch_out_dim)
271
268
 
272
269
  self.updated = True
273
270
  self.reduce_handle: Any = None
birder/net/ssl/franca.py CHANGED
@@ -124,13 +124,13 @@ class DINOHeadMRL(nn.Module):
124
124
 
125
125
 
126
126
  class SinkhornQueue(nn.Module):
127
- def __init__(self, queue_size: int) -> None:
127
+ def __init__(self, queue_size: int, dim: int) -> None:
128
128
  super().__init__()
129
129
  self.queue_size = queue_size
130
130
  self.active = True
131
- self.queue = nn.Buffer(torch.empty(0), persistent=False)
132
- self.queue_ptr: int = 0
133
- self.queue_full: bool = False
131
+ self.queue = nn.Buffer(torch.empty(queue_size, dim))
132
+ self.queue_ptr = nn.Buffer(torch.zeros(1, dtype=torch.long))
133
+ self.queue_full = nn.Buffer(torch.zeros(1, dtype=torch.bool))
134
134
 
135
135
  def set_active(self, active: bool) -> None:
136
136
  self.active = active
@@ -138,7 +138,7 @@ class SinkhornQueue(nn.Module):
138
138
  def get(self) -> Optional[torch.Tensor]:
139
139
  if self.active is False:
140
140
  return None
141
- if self.queue_full is False:
141
+ if self.queue_full.item() is False:
142
142
  return None
143
143
 
144
144
  return self.queue
@@ -152,17 +152,14 @@ class SinkhornQueue(nn.Module):
152
152
  if values.dim() != 2:
153
153
  raise ValueError("SinkhornQueue expects a 2D tensor")
154
154
 
155
- if self.queue.numel() == 0:
156
- self.queue = values.new_empty(self.queue_size, values.size(1))
157
-
158
155
  values = values.detach()
159
156
  if values.size(0) >= self.queue_size:
160
157
  self.queue.copy_(values[-self.queue_size :])
161
- self.queue_ptr = 0
162
- self.queue_full = True
158
+ self.queue_ptr.zero_()
159
+ self.queue_full.fill_(True)
163
160
  return
164
161
 
165
- ptr = self.queue_ptr
162
+ ptr = self.queue_ptr.item()
166
163
  end = ptr + values.size(0)
167
164
  if end <= self.queue_size:
168
165
  self.queue[ptr:end].copy_(values)
@@ -171,13 +168,15 @@ class SinkhornQueue(nn.Module):
171
168
  self.queue[ptr:].copy_(values[:first])
172
169
  self.queue[: end - self.queue_size].copy_(values[first:])
173
170
 
174
- self.queue_ptr = end % self.queue_size
171
+ self.queue_ptr.fill_(end % self.queue_size)
175
172
  if end >= self.queue_size:
176
- self.queue_full = True
173
+ self.queue_full.fill_(True)
177
174
 
178
175
 
179
176
  class DINOLossMRL(nn.Module):
180
- def __init__(self, student_temp: float, nesting_levels: int, queue_size: Optional[int] = None) -> None:
177
+ def __init__(
178
+ self, student_temp: float, nesting_levels: int, queue_size: Optional[int] = None, out_dim: Optional[int] = None
179
+ ) -> None:
181
180
  super().__init__()
182
181
  self.student_temp = student_temp
183
182
  self.queue_active = True
@@ -185,9 +184,12 @@ class DINOLossMRL(nn.Module):
185
184
  if queue_size is None:
186
185
  self.sinkhorn_queue = None
187
186
  else:
187
+ assert out_dim is not None, "out_dim is required when queue_size is set"
188
+
189
+ queue_dims = _get_nesting_list(out_dim, nesting_levels)
188
190
  self.sinkhorn_queue = nn.ModuleList()
189
- for _ in range(nesting_levels):
190
- queue = SinkhornQueue(queue_size)
191
+ for dim in queue_dims:
192
+ queue = SinkhornQueue(queue_size, dim)
191
193
  queue.set_active(self.queue_active)
192
194
  self.sinkhorn_queue.append(queue)
193
195
 
@@ -300,7 +302,9 @@ class DINOLossMRL(nn.Module):
300
302
 
301
303
  # pylint: disable=invalid-name
302
304
  class iBOTPatchLossMRL(nn.Module):
303
- def __init__(self, student_temp: float, nesting_levels: int, queue_size: Optional[int] = None) -> None:
305
+ def __init__(
306
+ self, student_temp: float, nesting_levels: int, queue_size: Optional[int] = None, out_dim: Optional[int] = None
307
+ ) -> None:
304
308
  super().__init__()
305
309
  self.student_temp = student_temp
306
310
  self.queue_active = True
@@ -308,9 +312,12 @@ class iBOTPatchLossMRL(nn.Module):
308
312
  if queue_size is None:
309
313
  self.sinkhorn_queue = None
310
314
  else:
315
+ assert out_dim is not None, "out_dim is required when queue_size is set"
316
+
317
+ queue_dims = _get_nesting_list(out_dim, nesting_levels)
311
318
  self.sinkhorn_queue = nn.ModuleList()
312
- for _ in range(nesting_levels):
313
- queue = SinkhornQueue(queue_size)
319
+ for dim in queue_dims:
320
+ queue = SinkhornQueue(queue_size, dim)
314
321
  queue.set_active(self.queue_active)
315
322
  self.sinkhorn_queue.append(queue)
316
323
 
birder/net/van.py CHANGED
@@ -206,11 +206,11 @@ class VAN(DetectorBackbone):
206
206
  if isinstance(m, nn.Linear):
207
207
  nn.init.trunc_normal_(m.weight, std=0.02)
208
208
  if m.bias is not None:
209
- nn.init.constant_(m.bias, 0)
209
+ nn.init.zeros_(m.bias)
210
210
 
211
211
  elif isinstance(m, nn.LayerNorm):
212
+ nn.init.ones_(m.weight)
212
213
  nn.init.zeros_(m.bias)
213
- nn.init.constant_(m.weight, 1.0)
214
214
 
215
215
  elif isinstance(m, nn.Conv2d):
216
216
  fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
birder/net/vit.py CHANGED
@@ -572,7 +572,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
572
572
  xs = self.encoder.forward_features(x, out_indices=self.out_indices)
573
573
 
574
574
  out: dict[str, torch.Tensor] = {}
575
- for stage_name, stage_x in zip(self.return_stages, xs):
575
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
576
576
  stage_x = stage_x[:, self.num_special_tokens :]
577
577
  stage_x = stage_x.permute(0, 2, 1)
578
578
  B, C, _ = stage_x.size()
@@ -802,6 +802,24 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
802
802
  # Register model configs (side effects)
803
803
  register_vit_configs(ViT)
804
804
 
805
+ registry.register_weights( # BioCLIP v1: https://arxiv.org/abs/2311.18803
806
+ "vit_b16_pn_bioclip-v1",
807
+ {
808
+ "url": "https://huggingface.co/birder-project/vit_b16_pn_bioclip-v1/resolve/main",
809
+ "description": (
810
+ "ViT b16 image encoder pre-trained by Imageomics using CLIP on the TreeOfLife-10M dataset. "
811
+ "This model has not been fine-tuned for a specific classification task"
812
+ ),
813
+ "resolution": (224, 224),
814
+ "formats": {
815
+ "pt": {
816
+ "file_size": 328.9,
817
+ "sha256": "9b2e5598f233657932eeb77e027cd4c4d683bf75515768fe6971cab6ec10bf15",
818
+ },
819
+ },
820
+ "net": {"network": "vit_b16_pn", "tag": "bioclip-v1"},
821
+ },
822
+ )
805
823
  registry.register_weights(
806
824
  "vit_l16_mim_200",
807
825
  {
@@ -849,8 +867,8 @@ registry.register_weights( # BioCLIP v2: https://arxiv.org/abs/2505.23883
849
867
  "resolution": (224, 224),
850
868
  "formats": {
851
869
  "pt": {
852
- "file_size": 1156.6,
853
- "sha256": "6cd7bd6993762590891fe2b41db1649cde5a0c4de5a7f341672f8856ed529d07",
870
+ "file_size": 1159.7,
871
+ "sha256": "301a325579dafdfa2ea13b0cbaf8129211ecd1429c29afa20d1c2eaaa91d8b0d",
854
872
  },
855
873
  },
856
874
  "net": {"network": "vit_l14_pn", "tag": "bioclip-v2"},
@@ -370,7 +370,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
370
370
  xs = self.encoder.forward_features(x, out_indices=self.out_indices)
371
371
 
372
372
  out: dict[str, torch.Tensor] = {}
373
- for stage_name, stage_x in zip(self.return_stages, xs):
373
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
374
374
  stage_x = stage_x[:, self.num_special_tokens :]
375
375
  stage_x = stage_x.permute(0, 2, 1)
376
376
  B, C, _ = stage_x.size()
birder/net/vit_sam.py CHANGED
@@ -29,7 +29,9 @@ from birder.net._vit_configs import BASE
29
29
  from birder.net._vit_configs import HUGE
30
30
  from birder.net._vit_configs import LARGE
31
31
  from birder.net._vit_configs import MEDIUM
32
+ from birder.net._vit_configs import SMALL
32
33
  from birder.net.base import DetectorBackbone
34
+ from birder.net.base import normalize_out_indices
33
35
  from birder.net.vit import EncoderBlock as MAEDecoderBlock
34
36
 
35
37
 
@@ -72,7 +74,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
72
74
 
73
75
  # Interpolate rel pos if needed
74
76
  if rel_pos.shape[0] != max_rel_dist:
75
- # Adjust size is a one off interpolation, should prevent us from getting here
77
+ # Only reached in dynamic-size mode (rel-pos table resized on the fly)
76
78
  rel_pos_resized = F.interpolate(
77
79
  rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
78
80
  )
@@ -242,6 +244,7 @@ class EncoderBlock(nn.Module):
242
244
  class ViT_SAM(DetectorBackbone):
243
245
  block_group_regex = r"body\.(\d+)"
244
246
 
247
+ # pylint: disable=too-many-locals
245
248
  def __init__(
246
249
  self,
247
250
  input_channels: int,
@@ -266,6 +269,7 @@ class ViT_SAM(DetectorBackbone):
266
269
  window_size: int = self.config["window_size"]
267
270
  global_attn_indexes: list[int] = self.config["global_attn_indexes"]
268
271
  neck_channels: Optional[int] = self.config.get("neck_channels", None)
272
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
269
273
  drop_path_rate: float = self.config["drop_path_rate"]
270
274
 
271
275
  if norm_layer_type == "LayerNorm":
@@ -292,6 +296,7 @@ class ViT_SAM(DetectorBackbone):
292
296
  self.hidden_dim = hidden_dim
293
297
  self.global_attn_indexes = global_attn_indexes
294
298
  self.num_special_tokens = 0
299
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
295
300
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
296
301
 
297
302
  self.patch_embed = PatchEmbed(
@@ -356,8 +361,10 @@ class ViT_SAM(DetectorBackbone):
356
361
  nn.Flatten(1),
357
362
  )
358
363
 
359
- self.return_stages = ["neck"] # Actually meaningless, but for completeness
360
- self.return_channels = [neck_channels]
364
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
365
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
366
+ self.return_channels = [hidden_dim] * num_return_stages
367
+ self.return_channels[-1] = neck_channels
361
368
  self.embedding_size = neck_channels
362
369
  self.classifier = self.create_classifier()
363
370
 
@@ -372,13 +379,54 @@ class ViT_SAM(DetectorBackbone):
372
379
  activation_layer=nn.GELU,
373
380
  )
374
381
 
382
+ def _get_pos_embed(self, H: int, W: int) -> torch.Tensor:
383
+ if self.dynamic_size is False:
384
+ return self.pos_embedding
385
+
386
+ if H == self.size[0] and W == self.size[1]:
387
+ return self.pos_embedding
388
+
389
+ base_h = H // self.patch_size
390
+ base_w = W // self.patch_size
391
+ orig_dtype = self.pos_embedding.dtype
392
+ pos_embedding = self.pos_embedding.float()
393
+ pos_embedding = pos_embedding.permute(0, 3, 1, 2)
394
+ pos_embedding = F.interpolate(pos_embedding, size=(base_h, base_w), mode="bicubic", antialias=True)
395
+ pos_embedding = pos_embedding.permute(0, 2, 3, 1)
396
+
397
+ return pos_embedding.to(orig_dtype)
398
+
399
+ def set_causal_attention(self, is_causal: bool = True) -> None:
400
+ for b in self.body:
401
+ b.set_causal_attention(is_causal)
402
+
375
403
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
404
+ H, W = x.shape[-2:]
376
405
  x = self.patch_embed(x)
377
- x = x + self.pos_embedding
406
+ x = x + self._get_pos_embed(H, W)
378
407
 
379
- x = self.body(x)
380
- x = self.neck(x.permute(0, 3, 1, 2))
381
- return {self.return_stages[0]: x}
408
+ if self.out_indices is None:
409
+ x = self.body(x)
410
+ x = self.neck(x.permute(0, 3, 1, 2))
411
+ return {self.return_stages[0]: x}
412
+
413
+ out_indices_set = set(self.out_indices)
414
+ last_out_idx = max(out_indices_set)
415
+ out: dict[str, torch.Tensor] = {}
416
+ stage_idx = 0
417
+ for idx, blk in enumerate(self.body):
418
+ x = blk(x)
419
+ if idx not in out_indices_set:
420
+ continue
421
+
422
+ stage_x = x.permute(0, 3, 1, 2)
423
+ if idx == last_out_idx:
424
+ stage_x = self.neck(stage_x)
425
+
426
+ out[self.return_stages[stage_idx]] = stage_x
427
+ stage_idx += 1
428
+
429
+ return out
382
430
 
383
431
  def freeze_stages(self, up_to_stage: int) -> None:
384
432
  for param in self.patch_embed.parameters():
@@ -393,13 +441,10 @@ class ViT_SAM(DetectorBackbone):
393
441
  for param in module.parameters():
394
442
  param.requires_grad_(False)
395
443
 
396
- def set_causal_attention(self, is_causal: bool = True) -> None:
397
- for b in self.body:
398
- b.set_causal_attention(is_causal)
399
-
400
444
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
445
+ H, W = x.shape[-2:]
401
446
  x = self.patch_embed(x)
402
- x = x + self.pos_embedding
447
+ x = x + self._get_pos_embed(H, W)
403
448
 
404
449
  x = self.body(x)
405
450
  x = self.neck(x.permute(0, 3, 1, 2))
@@ -410,9 +455,6 @@ class ViT_SAM(DetectorBackbone):
410
455
  x = self.forward_features(x)
411
456
  return self.features(x)
412
457
 
413
- def set_dynamic_size(self, dynamic_size: bool = True) -> None:
414
- assert dynamic_size is False, "Dynamic size not supported for this network"
415
-
416
458
  def adjust_size(self, new_size: tuple[int, int]) -> None:
417
459
  if new_size == self.size:
418
460
  return
@@ -530,6 +572,11 @@ class ViT_SAM(DetectorBackbone):
530
572
 
531
573
 
532
574
  # ViTDet (no neck)
575
+ registry.register_model_config(
576
+ "vit_det_s16",
577
+ ViT_SAM,
578
+ config={"patch_size": 16, **SMALL, "window_size": 14, "global_attn_indexes": [2, 5, 8, 11]},
579
+ )
533
580
  registry.register_model_config(
534
581
  "vit_det_m16_rms",
535
582
  ViT_SAM,
@@ -541,7 +588,6 @@ registry.register_model_config(
541
588
  "global_attn_indexes": [2, 5, 8, 11],
542
589
  },
543
590
  )
544
-
545
591
  registry.register_model_config(
546
592
  "vit_det_b16",
547
593
  ViT_SAM,