birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff compares the contents of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/net/nextvit.py
CHANGED
@@ -355,18 +355,18 @@ class NextViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         # Weights initialization
         for m in self.modules():
             if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.BatchNorm1d)):
-                nn.init.constant_(m.weight, 1.0)
-                nn.init.constant_(m.bias, 0.0)
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)

             elif isinstance(m, nn.Linear):
                 nn.init.normal_(m.weight, std=0.02)
                 if hasattr(m, "bias") and m.bias is not None:
-                    nn.init.constant_(m.bias, 0.0)
+                    nn.init.zeros_(m.bias)

             elif isinstance(m, nn.Conv2d):
                 nn.init.normal_(m.weight, std=0.02)
                 if hasattr(m, "bias") and m.bias is not None:
-                    nn.init.constant_(m.bias, 0.0)
+                    nn.init.zeros_(m.bias)

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         x = self.stem(x)
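The initialization change above recurs across this release (resnext.py, squeezenet.py, van.py and others): `nn.init.constant_(t, 1.0)` / `nn.init.constant_(t, 0.0)` calls are swapped for the equivalent `nn.init.ones_` / `nn.init.zeros_` helpers. A minimal standalone check of the equivalence (illustrative, not birder code):

import torch
from torch import nn

m = nn.LayerNorm(8)

# Old idiom
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
old_w, old_b = m.weight.detach().clone(), m.bias.detach().clone()

# New idiom - identical result, clearer intent
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)

assert torch.equal(old_w, m.weight) and torch.equal(old_b, m.bias)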
birder/net/resnext.py
CHANGED
@@ -205,8 +205,8 @@ class ResNeXt(DetectorBackbone):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-                nn.init.constant_(m.weight, 1.0)
-                nn.init.constant_(m.bias, 0.0)
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         x = self.stem(x)
birder/net/rope_deit3.py
CHANGED
@@ -223,7 +223,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
                 pt_grid_size=self.pt_grid_size,
             ),
             dim=-1,
-        ).to(self.rope.pos_embed.device)
+        ).to(self.rope.pos_embed.device, dtype=self.rope.pos_embed.dtype)

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         H, W = x.shape[-2:]
@@ -249,7 +249,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
         xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
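Two recurring fixes appear in this file and in the other ViT-family backbones below: the rotary-embedding tables are now moved to the `pos_embed` dtype as well as its device (relevant under mixed precision, where a freshly built fp32 table would otherwise be added to fp16/bf16 activations), and `zip()` over the return stages gains `strict=True` (Python 3.10+), turning a silent truncation into a loud error when `return_stages` and the encoder outputs disagree in length. A standalone illustration of the `strict` behavior:

stages = ["stage1", "stage2", "stage3"]
features = ["f1", "f2"]  # one output missing

# Old behavior: "stage3" is silently dropped
print(list(zip(stages, features)))  # [('stage1', 'f1'), ('stage2', 'f2')]

# New behavior: the mismatch raises immediately
try:
    list(zip(stages, features, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1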
birder/net/rope_flexivit.py
CHANGED
@@ -292,7 +292,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
                 pt_grid_size=self.pt_grid_size,
             ),
             dim=-1,
-        ).to(self.rope.pos_embed.device)
+        ).to(self.rope.pos_embed.device, dtype=self.rope.pos_embed.dtype)

     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
@@ -342,7 +342,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/rope_vit.py
CHANGED
@@ -648,7 +648,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
                 pt_grid_size=self.pt_grid_size,
             ),
             dim=-1,
-        ).to(self.rope.pos_embed.device)
+        ).to(self.rope.pos_embed.device, dtype=self.rope.pos_embed.dtype)

     def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
         for param in self.parameters():
@@ -698,7 +698,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/simple_vit.py
CHANGED
@@ -215,7 +215,7 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/squeezenet.py
CHANGED
@@ -76,7 +76,7 @@ class SqueezeNet(BaseNet):
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_uniform_(m.weight)
                 if m.bias is not None:
-                    nn.init.constant_(m.bias, 0.0)
+                    nn.init.zeros_(m.bias)

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
birder/net/ssl/capi.py
CHANGED
@@ -66,14 +66,20 @@ def sinkhorn_knopp_(M: torch.Tensor, temp: float, n_iterations: int, eps: float


 class SinkhornQueue(nn.Module):
-    def __init__(self, queue_size: int, position_wise: bool) -> None:
+    def __init__(self, queue_size: int, position_wise: bool, dim: int, seq_len: Optional[int] = None) -> None:
         super().__init__()
         self.queue_size = queue_size
         self.position_wise = position_wise
         self.active = True
-        self.
-
-
+        if self.position_wise is True:
+            assert seq_len is not None, "seq_len is required when position_wise is True"
+
+            self.queue = nn.Buffer(torch.empty(seq_len, queue_size, dim))
+        else:
+            self.queue = nn.Buffer(torch.empty(queue_size, dim))
+
+        self.queue_ptr = nn.Buffer(torch.zeros(1, dtype=torch.long))
+        self.queue_full = nn.Buffer(torch.zeros(1, dtype=torch.bool))

     def set_active(self, active: bool) -> None:
         self.active = active
@@ -81,13 +87,13 @@ class SinkhornQueue(nn.Module):
     def get(self) -> Optional[torch.Tensor]:
         if self.active is False:
             return None
-        if self.queue_full is False:
+        if self.queue_full.item() is False:
             return None

         return self.queue

     @torch.no_grad() # type: ignore[untyped-decorator]
-    def forward(self, values: torch.Tensor) -> None:
+    def forward(self, values: torch.Tensor) -> None:
         if self.active is False:
             return
         if values.numel() == 0:
@@ -98,21 +104,16 @@
             if values.dim() != 3:
                 raise ValueError("SinkhornQueue in position wise mode expects a 3D tensor")

-            seq_len = values.size(0)
             batch_size = values.size(1)
-            dim = values.size(2)
-
-            if self.queue.numel() == 0:
-                self.queue = values.new_empty(seq_len, self.queue_size, dim)

             values = values.detach()
             if batch_size >= self.queue_size:
                 self.queue.copy_(values[:, -self.queue_size :, :])
-                self.queue_ptr
-                self.queue_full
+                self.queue_ptr.zero_()
+                self.queue_full.fill_(True)
                 return

-            ptr = self.queue_ptr
+            ptr = self.queue_ptr.item()
             end = ptr + batch_size
             if end <= self.queue_size:
                 self.queue[:, ptr:end, :].copy_(values)
@@ -121,26 +122,23 @@
                 self.queue[:, ptr:, :].copy_(values[:, :first, :])
                 self.queue[:, : end - self.queue_size, :].copy_(values[:, first:, :])

-            self.queue_ptr
+            self.queue_ptr.fill_(end % self.queue_size)
             if end >= self.queue_size:
-                self.queue_full
+                self.queue_full.fill_(True)

         else:
             # values shape: (N, dim) - 2D
             if values.dim() != 2:
                 raise ValueError("SinkhornQueue in non-position wise mode expects a 2D tensor")

-            if self.queue.numel() == 0:
-                self.queue = values.new_empty(self.queue_size, values.size(1))
-
             values = values.detach()
             if values.size(0) >= self.queue_size:
                 self.queue.copy_(values[-self.queue_size :])
-                self.queue_ptr
-                self.queue_full
+                self.queue_ptr.zero_()
+                self.queue_full.fill_(True)
                 return

-            ptr = self.queue_ptr
+            ptr = self.queue_ptr.item()
             end = ptr + values.size(0)
             if end <= self.queue_size:
                 self.queue[ptr:end].copy_(values)
@@ -149,9 +147,9 @@
                 self.queue[ptr:].copy_(values[:first])
                 self.queue[: end - self.queue_size].copy_(values[first:])

-            self.queue_ptr
+            self.queue_ptr.fill_(end % self.queue_size)
             if end >= self.queue_size:
-                self.queue_full
+                self.queue_full.fill_(True)


 class OnlineClustering(nn.Module):
@@ -166,6 +164,7 @@ class OnlineClustering(nn.Module):
         pred_temp: float,
         position_wise_sk: bool = True,
         queue_size: Optional[int] = None,
+        seq_len: Optional[int] = None,
     ):
         super().__init__()
         self.n_sk_iter = n_sk_iter
@@ -176,7 +175,9 @@ class OnlineClustering(nn.Module):
         if queue_size is None:
             self.sinkhorn_queue = None
         else:
-            self.sinkhorn_queue = SinkhornQueue(
+            self.sinkhorn_queue = SinkhornQueue(
+                queue_size, position_wise=position_wise_sk, dim=out_dim, seq_len=seq_len
+            )

         # Weight initialization
         nn.init.normal_(self.layer.weight, std=1.0)
@@ -399,6 +400,11 @@ class CAPITeacher(SSLBaseNet):
         sk_mode: str = self.config["sk_mode"]
         queue_size: Optional[int] = self.config.get("queue_size", None)

+        queue_seq_len: Optional[int] = None
+        if sk_mode == "position-wise" and queue_size is not None:
+            input_size = (self.size[0] // self.backbone.max_stride, self.size[1] // self.backbone.max_stride)
+            queue_seq_len = input_size[0] * input_size[1]
+
         self.head = OnlineClustering(
             self.backbone.embedding_size,
             num_clusters,
@@ -408,6 +414,7 @@ class CAPITeacher(SSLBaseNet):
             pred_temp=pred_temp,
             position_wise_sk=sk_mode == "position-wise",
             queue_size=queue_size,
+            seq_len=queue_seq_len,
         )

     def forward( # type: ignore[override] # pylint: disable=arguments-differ
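The `SinkhornQueue` rework here (and in dino_v2.py and franca.py below) replaces lazily created tensors with eagerly allocated registered buffers (`torch.nn.Buffer`, a registered-buffer wrapper added in recent PyTorch releases), so the queue dimension, and in position-wise mode the sequence length, must now be known at construction time; in capi.py the sequence length is derived from the input resolution and the backbone stride, e.g. a 224x224 input with stride 16 yields a 14*14 = 196-token grid. Registered buffers are included in `state_dict()` and follow the module across `.to()` device moves, which attributes assigned on first use do not. A minimal ring-buffer sketch in the same spirit (an illustration, not the birder class):

import torch
from torch import nn


class RingQueue(nn.Module):
    """Minimal FIFO feature queue held in registered buffers (sketch only)."""

    def __init__(self, queue_size: int, dim: int) -> None:
        super().__init__()
        self.queue_size = queue_size
        # Eager allocation: the shape is known up front, so the buffers are
        # part of state_dict() and move with the module across devices
        self.queue = nn.Buffer(torch.empty(queue_size, dim))
        self.queue_ptr = nn.Buffer(torch.zeros(1, dtype=torch.long))
        self.queue_full = nn.Buffer(torch.zeros(1, dtype=torch.bool))

    @torch.no_grad()
    def push(self, values: torch.Tensor) -> None:
        values = values.detach()
        if values.size(0) >= self.queue_size:
            # Batch covers the whole queue: keep only the newest entries
            self.queue.copy_(values[-self.queue_size :])
            self.queue_ptr.zero_()
            self.queue_full.fill_(True)
            return

        ptr = self.queue_ptr.item()
        end = ptr + values.size(0)
        if end <= self.queue_size:
            self.queue[ptr:end].copy_(values)
        else:  # wrap around the ring
            first = self.queue_size - ptr
            self.queue[ptr:].copy_(values[:first])
            self.queue[: end - self.queue_size].copy_(values[first:])

        self.queue_ptr.fill_(end % self.queue_size)
        if end >= self.queue_size:
            self.queue_full.fill_(True)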
birder/net/ssl/dino_v2.py
CHANGED
@@ -76,13 +76,13 @@ class DINOHead(nn.Module):


 class SinkhornQueue(nn.Module):
-    def __init__(self, queue_size: int) -> None:
+    def __init__(self, queue_size: int, dim: int) -> None:
         super().__init__()
         self.queue_size = queue_size
         self.active = True
-        self.queue = nn.Buffer(torch.empty(
-        self.queue_ptr
-        self.queue_full
+        self.queue = nn.Buffer(torch.empty(queue_size, dim))
+        self.queue_ptr = nn.Buffer(torch.zeros(1, dtype=torch.long))
+        self.queue_full = nn.Buffer(torch.zeros(1, dtype=torch.bool))

     def set_active(self, active: bool) -> None:
         self.active = active
@@ -90,7 +90,7 @@ class SinkhornQueue(nn.Module):
     def get(self) -> Optional[torch.Tensor]:
         if self.active is False:
             return None
-        if self.queue_full is False:
+        if self.queue_full.item() is False:
             return None

         return self.queue
@@ -104,17 +104,14 @@ class SinkhornQueue(nn.Module):
         if values.dim() != 2:
             raise ValueError("SinkhornQueue expects a 2D tensor")

-        if self.queue.numel() == 0:
-            self.queue = values.new_empty(self.queue_size, values.size(1))
-
         values = values.detach()
         if values.size(0) >= self.queue_size:
             self.queue.copy_(values[-self.queue_size :])
-            self.queue_ptr
-            self.queue_full
+            self.queue_ptr.zero_()
+            self.queue_full.fill_(True)
             return

-        ptr = self.queue_ptr
+        ptr = self.queue_ptr.item()
         end = ptr + values.size(0)
         if end <= self.queue_size:
             self.queue[ptr:end].copy_(values)
@@ -123,9 +120,9 @@ class SinkhornQueue(nn.Module):
             self.queue[ptr:].copy_(values[:first])
             self.queue[: end - self.queue_size].copy_(values[first:])

-        self.queue_ptr
+        self.queue_ptr.fill_(end % self.queue_size)
         if end >= self.queue_size:
-            self.queue_full
+            self.queue_full.fill_(True)


 class DINOLoss(nn.Module):
@@ -139,7 +136,7 @@ class DINOLoss(nn.Module):
         if queue_size is None:
             self.sinkhorn_queue = None
         else:
-            self.sinkhorn_queue = SinkhornQueue(queue_size)
+            self.sinkhorn_queue = SinkhornQueue(queue_size, dim=out_dim)

         self.updated = True
         self.reduce_handle: Any = None
@@ -267,7 +264,7 @@ class iBOTPatchLoss(nn.Module):
         if queue_size is None:
             self.sinkhorn_queue = None
         else:
-            self.sinkhorn_queue = SinkhornQueue(queue_size)
+            self.sinkhorn_queue = SinkhornQueue(queue_size, dim=patch_out_dim)

         self.updated = True
         self.reduce_handle: Any = None
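One detail worth flagging in the `get()` change: with `queue_full` stored as a bool tensor, the old-style guard `self.queue_full is False` can never be true, because `is` compares object identity and a tensor is never the `False` singleton. Extracting a Python bool with `.item()` first makes the identity test meaningful:

import torch

queue_full = torch.zeros(1, dtype=torch.bool)

print(queue_full is False)         # False - a tensor is never the False object
print(queue_full.item() is False)  # True - .item() returns a plain Python bool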
birder/net/ssl/franca.py
CHANGED
@@ -124,13 +124,13 @@ class DINOHeadMRL(nn.Module):


 class SinkhornQueue(nn.Module):
-    def __init__(self, queue_size: int) -> None:
+    def __init__(self, queue_size: int, dim: int) -> None:
         super().__init__()
         self.queue_size = queue_size
         self.active = True
-        self.queue = nn.Buffer(torch.empty(
-        self.queue_ptr
-        self.queue_full
+        self.queue = nn.Buffer(torch.empty(queue_size, dim))
+        self.queue_ptr = nn.Buffer(torch.zeros(1, dtype=torch.long))
+        self.queue_full = nn.Buffer(torch.zeros(1, dtype=torch.bool))

     def set_active(self, active: bool) -> None:
         self.active = active
@@ -138,7 +138,7 @@ class SinkhornQueue(nn.Module):
     def get(self) -> Optional[torch.Tensor]:
         if self.active is False:
             return None
-        if self.queue_full is False:
+        if self.queue_full.item() is False:
             return None

         return self.queue
@@ -152,17 +152,14 @@ class SinkhornQueue(nn.Module):
         if values.dim() != 2:
             raise ValueError("SinkhornQueue expects a 2D tensor")

-        if self.queue.numel() == 0:
-            self.queue = values.new_empty(self.queue_size, values.size(1))
-
         values = values.detach()
         if values.size(0) >= self.queue_size:
             self.queue.copy_(values[-self.queue_size :])
-            self.queue_ptr
-            self.queue_full
+            self.queue_ptr.zero_()
+            self.queue_full.fill_(True)
             return

-        ptr = self.queue_ptr
+        ptr = self.queue_ptr.item()
         end = ptr + values.size(0)
         if end <= self.queue_size:
             self.queue[ptr:end].copy_(values)
@@ -171,13 +168,15 @@ class SinkhornQueue(nn.Module):
             self.queue[ptr:].copy_(values[:first])
             self.queue[: end - self.queue_size].copy_(values[first:])

-        self.queue_ptr
+        self.queue_ptr.fill_(end % self.queue_size)
         if end >= self.queue_size:
-            self.queue_full
+            self.queue_full.fill_(True)


 class DINOLossMRL(nn.Module):
-    def __init__(
+    def __init__(
+        self, student_temp: float, nesting_levels: int, queue_size: Optional[int] = None, out_dim: Optional[int] = None
+    ) -> None:
         super().__init__()
         self.student_temp = student_temp
         self.queue_active = True
@@ -185,9 +184,12 @@ class DINOLossMRL(nn.Module):
         if queue_size is None:
             self.sinkhorn_queue = None
         else:
+            assert out_dim is not None, "out_dim is required when queue_size is set"
+
+            queue_dims = _get_nesting_list(out_dim, nesting_levels)
             self.sinkhorn_queue = nn.ModuleList()
-            for
-                queue = SinkhornQueue(queue_size)
+            for dim in queue_dims:
+                queue = SinkhornQueue(queue_size, dim)
                 queue.set_active(self.queue_active)
                 self.sinkhorn_queue.append(queue)

@@ -300,7 +302,9 @@

 # pylint: disable=invalid-name
 class iBOTPatchLossMRL(nn.Module):
-    def __init__(
+    def __init__(
+        self, student_temp: float, nesting_levels: int, queue_size: Optional[int] = None, out_dim: Optional[int] = None
+    ) -> None:
         super().__init__()
         self.student_temp = student_temp
         self.queue_active = True
@@ -308,9 +312,12 @@ class iBOTPatchLossMRL(nn.Module):
         if queue_size is None:
             self.sinkhorn_queue = None
         else:
+            assert out_dim is not None, "out_dim is required when queue_size is set"
+
+            queue_dims = _get_nesting_list(out_dim, nesting_levels)
             self.sinkhorn_queue = nn.ModuleList()
-            for
-                queue = SinkhornQueue(queue_size)
+            for dim in queue_dims:
+                queue = SinkhornQueue(queue_size, dim)
                 queue.set_active(self.queue_active)
                 self.sinkhorn_queue.append(queue)
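In the MRL (Matryoshka) losses, each nesting level now gets its own queue sized to that level's embedding width via `_get_nesting_list(out_dim, nesting_levels)`. The helper's implementation is not part of this diff; assuming the common Matryoshka convention of halving the width per level, the per-level queue dims would look like this hypothetical sketch:

# Hypothetical stand-in for birder's _get_nesting_list, assuming widths
# halve per nesting level; the real helper may order or derive them differently.
def get_nesting_list(out_dim: int, nesting_levels: int) -> list[int]:
    return [out_dim // (2 ** (nesting_levels - 1 - i)) for i in range(nesting_levels)]

print(get_nesting_list(256, 3))  # [64, 128, 256] - one SinkhornQueue per width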
birder/net/van.py
CHANGED
@@ -206,11 +206,11 @@ class VAN(DetectorBackbone):
         if isinstance(m, nn.Linear):
             nn.init.trunc_normal_(m.weight, std=0.02)
             if m.bias is not None:
-                nn.init.constant_(m.bias, 0.0)
+                nn.init.zeros_(m.bias)

         elif isinstance(m, nn.LayerNorm):
+            nn.init.ones_(m.weight)
             nn.init.zeros_(m.bias)
-            nn.init.constant_(m.weight, 1.0)

         elif isinstance(m, nn.Conv2d):
             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
birder/net/vit.py
CHANGED
@@ -572,7 +572,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
@@ -802,6 +802,24 @@
 # Register model configs (side effects)
 register_vit_configs(ViT)

+registry.register_weights( # BioCLIP v1: https://arxiv.org/abs/2311.18803
+    "vit_b16_pn_bioclip-v1",
+    {
+        "url": "https://huggingface.co/birder-project/vit_b16_pn_bioclip-v1/resolve/main",
+        "description": (
+            "ViT b16 image encoder pre-trained by Imageomics using CLIP on the TreeOfLife-10M dataset. "
+            "This model has not been fine-tuned for a specific classification task"
+        ),
+        "resolution": (224, 224),
+        "formats": {
+            "pt": {
+                "file_size": 328.9,
+                "sha256": "9b2e5598f233657932eeb77e027cd4c4d683bf75515768fe6971cab6ec10bf15",
+            },
+        },
+        "net": {"network": "vit_b16_pn", "tag": "bioclip-v1"},
+    },
+)
 registry.register_weights(
     "vit_l16_mim_200",
     {
@@ -849,8 +867,8 @@ registry.register_weights( # BioCLIP v2: https://arxiv.org/abs/2505.23883
         "resolution": (224, 224),
         "formats": {
             "pt": {
-                "file_size":
-                "sha256": "
+                "file_size": 1159.7,
+                "sha256": "301a325579dafdfa2ea13b0cbaf8129211ecd1429c29afa20d1c2eaaa91d8b0d",
             },
         },
         "net": {"network": "vit_l14_pn", "tag": "bioclip-v2"},
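The new `register_weights` entries follow the established schema: a base download URL, a description, the native resolution, per-format metadata (`file_size`, presumably in megabytes, plus a `sha256` digest), and the network/tag pair the weights attach to. The digest makes downloads verifiable; a generic check could look like this (the helper and filename are illustrative, not a birder API):

import hashlib
from pathlib import Path

def sha256_of(path: Path) -> str:
    # Stream in 1 MiB chunks so large checkpoints need not fit in memory
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "9b2e5598f233657932eeb77e027cd4c4d683bf75515768fe6971cab6ec10bf15"
# assert sha256_of(Path("vit_b16_pn_bioclip-v1.pt")) == expected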
birder/net/vit_parallel.py
CHANGED
@@ -370,7 +370,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/vit_sam.py
CHANGED
@@ -29,7 +29,9 @@ from birder.net._vit_configs import BASE
 from birder.net._vit_configs import HUGE
 from birder.net._vit_configs import LARGE
 from birder.net._vit_configs import MEDIUM
+from birder.net._vit_configs import SMALL
 from birder.net.base import DetectorBackbone
+from birder.net.base import normalize_out_indices
 from birder.net.vit import EncoderBlock as MAEDecoderBlock


@@ -72,7 +74,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor

     # Interpolate rel pos if needed
     if rel_pos.shape[0] != max_rel_dist:
-        #
+        # Only reached in dynamic-size mode (rel-pos table resized on the fly)
         rel_pos_resized = F.interpolate(
             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
         )
@@ -242,6 +244,7 @@ class EncoderBlock(nn.Module):
 class ViT_SAM(DetectorBackbone):
     block_group_regex = r"body\.(\d+)"

+    # pylint: disable=too-many-locals
     def __init__(
         self,
         input_channels: int,
@@ -266,6 +269,7 @@ class ViT_SAM(DetectorBackbone):
         window_size: int = self.config["window_size"]
         global_attn_indexes: list[int] = self.config["global_attn_indexes"]
         neck_channels: Optional[int] = self.config.get("neck_channels", None)
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         drop_path_rate: float = self.config["drop_path_rate"]

         if norm_layer_type == "LayerNorm":
@@ -292,6 +296,7 @@ class ViT_SAM(DetectorBackbone):
         self.hidden_dim = hidden_dim
         self.global_attn_indexes = global_attn_indexes
         self.num_special_tokens = 0
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule

         self.patch_embed = PatchEmbed(
@@ -356,8 +361,10 @@ class ViT_SAM(DetectorBackbone):
             nn.Flatten(1),
         )

-        self.
-        self.
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
+        self.return_channels[-1] = neck_channels
         self.embedding_size = neck_channels
         self.classifier = self.create_classifier()
@@ -372,13 +379,54 @@
             activation_layer=nn.GELU,
         )

+    def _get_pos_embed(self, H: int, W: int) -> torch.Tensor:
+        if self.dynamic_size is False:
+            return self.pos_embedding
+
+        if H == self.size[0] and W == self.size[1]:
+            return self.pos_embedding
+
+        base_h = H // self.patch_size
+        base_w = W // self.patch_size
+        orig_dtype = self.pos_embedding.dtype
+        pos_embedding = self.pos_embedding.float()
+        pos_embedding = pos_embedding.permute(0, 3, 1, 2)
+        pos_embedding = F.interpolate(pos_embedding, size=(base_h, base_w), mode="bicubic", antialias=True)
+        pos_embedding = pos_embedding.permute(0, 2, 3, 1)
+
+        return pos_embedding.to(orig_dtype)
+
+    def set_causal_attention(self, is_causal: bool = True) -> None:
+        for b in self.body:
+            b.set_causal_attention(is_causal)
+
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        H, W = x.shape[-2:]
         x = self.patch_embed(x)
-        x = x + self.
+        x = x + self._get_pos_embed(H, W)

-
-
-
+        if self.out_indices is None:
+            x = self.body(x)
+            x = self.neck(x.permute(0, 3, 1, 2))
+            return {self.return_stages[0]: x}
+
+        out_indices_set = set(self.out_indices)
+        last_out_idx = max(out_indices_set)
+        out: dict[str, torch.Tensor] = {}
+        stage_idx = 0
+        for idx, blk in enumerate(self.body):
+            x = blk(x)
+            if idx not in out_indices_set:
+                continue
+
+            stage_x = x.permute(0, 3, 1, 2)
+            if idx == last_out_idx:
+                stage_x = self.neck(stage_x)
+
+            out[self.return_stages[stage_idx]] = stage_x
+            stage_idx += 1
+
+        return out

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.patch_embed.parameters():
@@ -393,13 +441,10 @@
             for param in module.parameters():
                 param.requires_grad_(False)

-    def set_causal_attention(self, is_causal: bool = True) -> None:
-        for b in self.body:
-            b.set_causal_attention(is_causal)
-
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        H, W = x.shape[-2:]
         x = self.patch_embed(x)
-        x = x + self.
+        x = x + self._get_pos_embed(H, W)

         x = self.body(x)
         x = self.neck(x.permute(0, 3, 1, 2))
@@ -410,9 +455,6 @@
         x = self.forward_features(x)
         return self.features(x)

-    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
-        assert dynamic_size is False, "Dynamic size not supported for this network"
-
     def adjust_size(self, new_size: tuple[int, int]) -> None:
         if new_size == self.size:
             return
@@ -530,6 +572,11 @@


 # ViTDet (no neck)
+registry.register_model_config(
+    "vit_det_s16",
+    ViT_SAM,
+    config={"patch_size": 16, **SMALL, "window_size": 14, "global_attn_indexes": [2, 5, 8, 11]},
+)
 registry.register_model_config(
     "vit_det_m16_rms",
     ViT_SAM,
@@ -541,7 +588,6 @@ registry.register_model_config(
         "global_attn_indexes": [2, 5, 8, 11],
     },
 )
-
 registry.register_model_config(
     "vit_det_b16",
     ViT_SAM,