birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/net/deit3.py
CHANGED

@@ -185,7 +185,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
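Note: the only functional change above is the strict=True flag on zip(), available since Python 3.10. It turns a silent truncation into an error whenever return_stages and the encoder's returned feature list differ in length. A minimal standalone illustration (not birder code):

    stages = ["stage1", "stage2", "stage3"]
    features = ["f1", "f2"]  # one feature map missing

    list(zip(stages, features))               # [('stage1', 'f1'), ('stage2', 'f2')] - silently drops 'stage3'
    list(zip(stages, features, strict=True))  # raises ValueError because the lengths differ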
birder/net/detection/__init__.py
CHANGED

@@ -3,6 +3,7 @@ from birder.net.detection.detr import DETR
 from birder.net.detection.efficientdet import EfficientDet
 from birder.net.detection.faster_rcnn import Faster_RCNN
 from birder.net.detection.fcos import FCOS
+from birder.net.detection.lw_detr import LW_DETR
 from birder.net.detection.plain_detr import Plain_DETR
 from birder.net.detection.retinanet import RetinaNet
 from birder.net.detection.rt_detr_v1 import RT_DETR_v1
@@ -21,6 +22,7 @@ __all__ = [
     "EfficientDet",
     "Faster_RCNN",
     "FCOS",
+    "LW_DETR",
     "Plain_DETR",
     "RetinaNet",
     "RT_DETR_v1",
birder/net/detection/base.py
CHANGED

@@ -44,6 +44,7 @@ class DetectionBaseNet(nn.Module):
     block_group_regex: Optional[str]
     auto_register = False
     scriptable = True
+    exportable = True
     task = str(Task.OBJECT_DETECTION)

     def __init_subclass__(cls) -> None:
@@ -134,40 +135,62 @@ class DetectionBaseNet(nn.Module):
                     f" Found invalid box {degenerate_bb} for target at index {target_idx}.",
                 )

-
-
+    def _to_img_list(self, x: torch.Tensor, image_sizes: Optional[list[tuple[int, int]]] = None) -> "ImageList":
+        B = x.size(0)
         if image_sizes is None:
-            torch.
-            )
-            image_sizes_list.append((image_size[0], image_size[1]))
+            H = x.size(2)
+            W = x.size(3)
+            h_tensor = torch.full((B,), H, dtype=torch.int64, device=x.device)
+            w_tensor = torch.full((B,), W, dtype=torch.int64, device=x.device)
+            image_sizes_tensor = torch.stack([h_tensor, w_tensor], dim=1)
+        else:
+            image_sizes_tensor = torch.tensor(image_sizes, dtype=torch.int64, device=x.device)

-        return ImageList(x,
+        return ImageList(x, image_sizes_tensor)

     def forward(
         self,
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         # TypedDict not supported for TorchScript - avoid returning DetectorResultType
         raise NotImplementedError


 class ImageList:
-    def __init__(self, tensors: torch.Tensor, image_sizes:
+    def __init__(self, tensors: torch.Tensor, image_sizes: torch.Tensor) -> None:
         self.tensors = tensors
-        self.image_sizes = image_sizes
+        self.image_sizes = image_sizes  # Shape: (B, 2) with [H, W] format

     def to(self, device: torch.device) -> "ImageList":
         cast_tensor = self.tensors.to(device)
-
+        cast_sizes = self.image_sizes.to(device)
+        return ImageList(cast_tensor, cast_sizes)
+
+
+def clip_boxes_to_image(boxes: torch.Tensor, image_size: torch.Tensor) -> torch.Tensor:
+    """
+    Clip boxes to image boundaries
+
+    Parameters
+    ----------
+    boxes
+        Boxes in (x1, y1, x2, y2) format, shape (..., 4)
+    image_size
+        Tensor of [height, width]
+
+    Returns
+    -------
+    Clipped boxes
+    """
+
+    boxes_x = boxes[..., 0::2].clamp(min=0, max=image_size[1])
+    boxes_y = boxes[..., 1::2].clamp(min=0, max=image_size[0])
+    clipped_boxes = torch.stack([boxes_x[..., 0], boxes_y[..., 0], boxes_x[..., 1], boxes_y[..., 1]], dim=-1)
+
+    return clipped_boxes


 ###############################################################################
@@ -325,7 +348,7 @@ class SimpleFeaturePyramidNetwork(nn.Module):


 # pylint: disable=protected-access,too-many-locals
-@torch.jit._script_if_tracing  # type: ignore
+@torch.jit._script_if_tracing  # type: ignore[untyped-decorator]
 def encode_boxes(reference_boxes: torch.Tensor, proposals: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
     """
     Encode a set of proposals with respect to some reference boxes
@@ -609,7 +632,7 @@ class Matcher(nn.Module):

         if match_quality_matrix.numel() == 0:
             # Empty targets or proposals not supported during training
-            if match_quality_matrix.
+            if match_quality_matrix.size(0) == 0:
                 raise ValueError("No ground-truth boxes available for one of the images during training")

             raise ValueError("No proposal boxes available for one of the images during training")
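Note: taken together, the base.py hunks switch ImageList.image_sizes from a Python list of (height, width) tuples to a single (B, 2) int64 tensor in [H, W] order, and add a module-level clip_boxes_to_image() helper. A rough standalone sketch of the same pattern in plain PyTorch (illustrative only, not the birder API):

    import torch

    # Batch of two 480x640 images; sizes kept as a (B, 2) tensor in [H, W] order
    x = torch.rand(2, 3, 480, 640)
    b, h, w = x.size(0), x.size(2), x.size(3)
    image_sizes = torch.stack(
        [
            torch.full((b,), h, dtype=torch.int64),
            torch.full((b,), w, dtype=torch.int64),
        ],
        dim=1,
    )  # tensor([[480, 640], [480, 640]])

    # Clamping a box against one image, mirroring the logic of clip_boxes_to_image()
    boxes = torch.tensor([[-5.0, 10.0, 700.0, 300.0]])  # (x1, y1, x2, y2)
    img_h, img_w = int(image_sizes[0, 0]), int(image_sizes[0, 1])
    clipped = torch.stack(
        [
            boxes[..., 0].clamp(0, img_w),
            boxes[..., 1].clamp(0, img_h),
            boxes[..., 2].clamp(0, img_w),
            boxes[..., 3].clamp(0, img_h),
        ],
        dim=-1,
    )  # tensor([[0., 10., 640., 300.]])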
birder/net/detection/deformable_detr.py
CHANGED

@@ -56,7 +56,7 @@ class HungarianMatcher(nn.Module):
     @torch.jit.unused  # type: ignore[untyped-decorator]
     def forward(
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
-    ) -> list[torch.Tensor]:
+    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
         with torch.no_grad():
             B, num_queries = class_logits.shape[:2]

@@ -135,7 +135,7 @@ class MultiScaleDeformableAttention(nn.Module):
         self.reset_parameters()

     def reset_parameters(self) -> None:
-        nn.init.
+        nn.init.zeros_(self.sampling_offsets.weight)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = (
@@ -149,12 +149,12 @@ class MultiScaleDeformableAttention(nn.Module):
         with torch.no_grad():
             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))

-        nn.init.
-        nn.init.
+        nn.init.zeros_(self.attention_weights.weight)
+        nn.init.zeros_(self.attention_weights.bias)
         nn.init.xavier_uniform_(self.value_proj.weight)
-        nn.init.
+        nn.init.zeros_(self.value_proj.bias)
         nn.init.xavier_uniform_(self.output_proj.weight)
-        nn.init.
+        nn.init.zeros_(self.output_proj.bias)

     def forward(
         self,
@@ -164,10 +164,11 @@ class MultiScaleDeformableAttention(nn.Module):
         input_spatial_shapes: torch.Tensor,
         input_level_start_index: torch.Tensor,
         input_padding_mask: Optional[torch.Tensor] = None,
+        src_shapes: Optional[list[list[int]]] = None,
     ) -> torch.Tensor:
         N, num_queries, _ = query.size()
         N, sequence_length, _ = input_flatten.size()
-        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length
+        # assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length

         value = self.value_proj(input_flatten)
         if input_padding_mask is not None:
@@ -208,6 +209,7 @@ class MultiScaleDeformableAttention(nn.Module):
             sampling_locations,
             attention_weights,
             self.im2col_step,
+            src_shapes,
         )

         output = self.output_proj(output)
@@ -235,8 +237,9 @@ class DeformableTransformerEncoderLayer(nn.Module):
         spatial_shapes: torch.Tensor,
         level_start_index: torch.Tensor,
         mask: Optional[torch.Tensor],
+        src_shapes: Optional[list[list[int]]] = None,
     ) -> torch.Tensor:
-        src2 = self.self_attn(src + pos, reference_points, src, spatial_shapes, level_start_index, mask)
+        src2 = self.self_attn(src + pos, reference_points, src, spatial_shapes, level_start_index, mask, src_shapes)
         src = src + self.dropout(src2)
         src = self.norm1(src)

@@ -277,13 +280,13 @@ class DeformableTransformerDecoderLayer(nn.Module):
         level_start_index: torch.Tensor,
         src_padding_mask: Optional[torch.Tensor],
         self_attn_mask: Optional[torch.Tensor] = None,
+        src_shapes: Optional[list[list[int]]] = None,
     ) -> torch.Tensor:
         # Self attention
-
-        k = tgt + query_pos
+        q_k = tgt + query_pos

         tgt2, _ = self.self_attn(
-
+            q_k.transpose(0, 1), q_k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
         )
         tgt2 = tgt2.transpose(0, 1)
         tgt = tgt + self.dropout(tgt2)
@@ -291,7 +294,7 @@ class DeformableTransformerDecoderLayer(nn.Module):

         # Cross attention
         tgt2 = self.cross_attn(
-            tgt + query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask
+            tgt + query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask, src_shapes
         )
         tgt = tgt + self.dropout(tgt2)
         tgt = self.norm2(tgt)
@@ -311,17 +314,15 @@ class DeformableTransformerEncoder(nn.Module):

     @staticmethod
     def get_reference_points(
-
+        src_shapes: list[list[int]], valid_ratios: torch.Tensor, device: torch.device
     ) -> torch.Tensor:
         reference_points_list = []
-        for lvl,
-
-
-            ref_y,
-
-
-            indexing="ij",
-        )
+        for lvl, (H, W) in enumerate(src_shapes):
+            # Use arange instead of linspace - works with symbolic sizes
+            # linspace(0.5, H-0.5, H) is equivalent to arange(H) + 0.5
+            ref_y = (torch.arange(H, dtype=torch.float32, device=device) + 0.5).view(-1, 1).expand(-1, W)
+            ref_x = (torch.arange(W, dtype=torch.float32, device=device) + 0.5).view(1, -1).expand(H, -1)
+
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
             ref = torch.stack((ref_x, ref_y), dim=-1)
@@ -336,15 +337,16 @@ class DeformableTransformerEncoder(nn.Module):
         self,
         src: torch.Tensor,
         spatial_shapes: torch.Tensor,
+        src_shapes: list[list[int]],
         level_start_index: torch.Tensor,
         pos: torch.Tensor,
         valid_ratios: torch.Tensor,
         mask: torch.Tensor,
     ) -> torch.Tensor:
         out = src
-        reference_points = self.get_reference_points(
+        reference_points = self.get_reference_points(src_shapes, valid_ratios, device=src.device)
         for layer in self.layers:
-            out = layer(out, pos, reference_points, spatial_shapes, level_start_index, mask)
+            out = layer(out, pos, reference_points, spatial_shapes, level_start_index, mask, src_shapes)

         return out

@@ -369,6 +371,7 @@ class DeformableTransformerDecoder(nn.Module):
         query_pos: torch.Tensor,
         src_valid_ratios: torch.Tensor,
         src_padding_mask: torch.Tensor,
+        src_shapes: Optional[list[list[int]]] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         output = tgt

@@ -391,6 +394,7 @@ class DeformableTransformerDecoder(nn.Module):
                 src_spatial_shapes,
                 src_level_start_index,
                 src_padding_mask,
+                src_shapes=src_shapes,
             )

             if self.box_refine is True:
@@ -482,10 +486,11 @@ class DeformableTransformer(nn.Module):
         src_list = []
         lvl_pos_embed_list = []
         mask_list = []
-
+        src_shapes: list[list[int]] = []  # list[tuple[int, int]] not supported on TorchScript
+
         for lvl, (src, pos_embed, mask) in enumerate(zip(srcs, pos_embeds, masks)):
-
-
+            H, W = src.shape[-2], src.shape[-1]
+            src_shapes.append([H, W])
             src = src.flatten(2).transpose(1, 2)
             pos_embed = pos_embed.flatten(2).transpose(1, 2)
             mask = mask.flatten(1)
@@ -497,13 +502,19 @@ class DeformableTransformer(nn.Module):
         src_flatten = torch.concat(src_list, dim=1)
         mask_flatten = torch.concat(mask_list, dim=1)
         lvl_pos_embed_flatten = torch.concat(lvl_pos_embed_list, dim=1)
-        spatial_shapes = torch.
+        spatial_shapes = torch.tensor(src_shapes, dtype=torch.long, device=src_flatten.device)
         level_start_index = torch.concat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]), dim=0)
         valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], dim=1)

         # Encoder
         memory = self.encoder(
-            src_flatten,
+            src_flatten,
+            spatial_shapes,
+            src_shapes,
+            level_start_index,
+            lvl_pos_embed_flatten,
+            valid_ratios,
+            mask_flatten,
         )

         # Prepare input for decoder
@@ -515,7 +526,15 @@ class DeformableTransformer(nn.Module):

         # Decoder
         hs, inter_references = self.decoder(
-            tgt,
+            tgt,
+            reference_points,
+            memory,
+            spatial_shapes,
+            level_start_index,
+            query_embed,
+            valid_ratios,
+            mask_flatten,
+            src_shapes,
         )

         return (hs, reference_points, inter_references)
@@ -587,7 +606,7 @@ class Deformable_DETR(DetectionBaseNet):

         self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
-        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)

         class_embed = nn.Linear(hidden_dim, self.num_classes)
         bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
@@ -641,7 +660,8 @@ class Deformable_DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)

-
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -650,7 +670,7 @@ class Deformable_DETR(DetectionBaseNet):
         self,
         cls_logits: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
@@ -675,7 +695,7 @@ class Deformable_DETR(DetectionBaseNet):
         self,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)
@@ -709,7 +729,7 @@ class Deformable_DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)

-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)

         loss_ce_list = []
         loss_bbox_list = []
@@ -734,7 +754,7 @@ class Deformable_DETR(DetectionBaseNet):
         return losses

     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor,
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=100, dim=1)
@@ -743,14 +763,12 @@ class Deformable_DETR(DetectionBaseNet):
         labels = topk_indexes % class_logits.shape[2]
         labels += 1  # Background offset

-        target_sizes = torch.tensor(image_shapes, device=class_logits.device)
-
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w =
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -776,16 +794,7 @@ class Deformable_DETR(DetectionBaseNet):
         return detections

     # pylint: disable=too-many-locals
-    def
-        self,
-        x: torch.Tensor,
-        targets: Optional[list[dict[str, torch.Tensor]]] = None,
-        masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
-    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
-        self._input_check(targets)
-        images = self._to_img_list(x, image_sizes)
-
+    def forward_net(self, x: torch.Tensor, masks: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
         feature_list = list(features.values())
         mask_list = []
@@ -829,6 +838,20 @@ class Deformable_DETR(DetectionBaseNet):
         outputs_class = torch.stack(outputs_classes)
         outputs_coord = torch.stack(outputs_coords)

+        return (outputs_class, outputs_coord)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        targets: Optional[list[dict[str, torch.Tensor]]] = None,
+        masks: Optional[torch.Tensor] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
+        self._input_check(targets)
+        image_sizes_tensor = self._to_img_list(x, image_sizes).image_sizes
+
+        outputs_class, outputs_coord = self.forward_net(x, masks)
+
         losses = {}
         detections: list[dict[str, torch.Tensor]] = []
         if self.training is True:
@@ -838,14 +861,15 @@ class Deformable_DETR(DetectionBaseNet):
             for idx, target in enumerate(targets):
                 boxes = target["boxes"]
                 boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-
+                scale = image_sizes_tensor[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+                boxes = boxes / scale
                 targets[idx]["boxes"] = boxes
                 targets[idx]["labels"] = target["labels"] - 1  # No background

             losses = self.compute_loss(targets, outputs_class, outputs_coord)

         else:
-            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1],
+            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1], image_sizes_tensor)

         return (detections, losses)
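Note: several of the hunks above only tighten typing. The Hungarian matcher returns one (query_indices, target_indices) pair per image, and _get_src_permutation_idx flattens those pairs into batch/source index tensors; the rewritten get_reference_points relies on the equivalence stated in its own comment, arange(H) + 0.5 matching linspace(0.5, H - 0.5, H). A small standalone sketch of the index flattening, mirroring the code shown above (illustrative only):

    import torch

    # One (query_idx, target_idx) pair per image, as a Hungarian matcher returns
    indices = [
        (torch.tensor([3, 7]), torch.tensor([0, 1])),  # image 0: queries 3 and 7 matched to targets 0 and 1
        (torch.tensor([1]), torch.tensor([0])),        # image 1: query 1 matched to target 0
    ]

    batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.concat([src for (src, _) in indices])
    print(batch_idx)  # tensor([0, 0, 1])
    print(src_idx)    # tensor([3, 7, 1])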
birder/net/detection/detr.py
CHANGED

@@ -49,7 +49,7 @@ class HungarianMatcher(nn.Module):
     @torch.jit.unused  # type: ignore[untyped-decorator]
     def forward(
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
-    ) -> list[torch.Tensor]:
+    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
         with torch.no_grad():
             B, num_queries = class_logits.shape[:2]

@@ -148,10 +148,9 @@ class TransformerDecoderLayer(nn.Module):
         query_pos: torch.Tensor,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        k = tgt + query_pos
+        q_k = tgt + query_pos

-        tgt2, _ = self.self_attn(
+        tgt2, _ = self.self_attn(q_k, q_k, value=tgt, need_weights=False)
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
         tgt2, _ = self.multihead_attn(
@@ -341,7 +340,7 @@ class DETR(DetectionBaseNet):
         )
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

-        self.matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=1.0, cost_bbox=5.0, cost_giou=2.0)
         empty_weight = torch.ones(self.num_classes)
         empty_weight[0] = 0.1
         self.empty_weight = nn.Buffer(empty_weight)
@@ -365,7 +364,8 @@ class DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)

-
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -374,7 +374,7 @@ class DETR(DetectionBaseNet):
         self,
         cls_logits: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
         target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
@@ -388,7 +388,7 @@ class DETR(DetectionBaseNet):
         self,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)
@@ -422,7 +422,7 @@ class DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)

-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)

         loss_ce_list = []
         loss_bbox_list = []
@@ -447,20 +447,17 @@ class DETR(DetectionBaseNet):
         return losses

     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor,
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         prob = F.softmax(class_logits, -1)
         scores, labels = prob[..., 1:].max(-1)
         labels = labels + 1

-        # TorchScript doesn't support creating tensor from tuples, convert everything to lists
-        target_sizes = torch.tensor([list(s) for s in image_shapes], device=class_logits.device)
-
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w =
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -485,16 +482,7 @@ class DETR(DetectionBaseNet):

         return detections

-    def
-        self,
-        x: torch.Tensor,
-        targets: Optional[list[dict[str, torch.Tensor]]] = None,
-        masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
-    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
-        self._input_check(targets)
-        images = self._to_img_list(x, image_sizes)
-
+    def forward_net(self, x: torch.Tensor, masks: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
         x = features[self.backbone.return_stages[-1]]
         if masks is not None:
@@ -505,6 +493,20 @@ class DETR(DetectionBaseNet):
         outputs_class = self.class_embed(hs)
         outputs_coord = self.bbox_embed(hs).sigmoid()

+        return (outputs_class, outputs_coord)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        targets: Optional[list[dict[str, torch.Tensor]]] = None,
+        masks: Optional[torch.Tensor] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
+        self._input_check(targets)
+        image_sizes_tensor = self._to_img_list(x, image_sizes).image_sizes
+
+        outputs_class, outputs_coord = self.forward_net(x, masks)
+
         losses = {}
         detections: list[dict[str, torch.Tensor]] = []
         if self.training is True:
@@ -514,13 +516,14 @@ class DETR(DetectionBaseNet):
             for idx, target in enumerate(targets):
                 boxes = target["boxes"]
                 boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-
+                scale = image_sizes_tensor[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+                boxes = boxes / scale
                 targets[idx]["boxes"] = boxes

             losses = self.compute_loss(targets, outputs_class, outputs_coord)

         else:
-            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1],
+            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1], image_sizes_tensor)

         return (detections, losses)
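Note: in both DETR variants, postprocess_detections now receives the (B, 2) image-size tensor directly and scales the normalized cxcywh outputs back to absolute xyxy pixel coordinates via unbind/stack, as shown in the hunks above. A condensed standalone sketch of that conversion (illustrative only; uses torchvision's box_convert, with made-up numbers):

    import torch
    from torchvision.ops import box_convert

    box_regression = torch.tensor([[[0.5, 0.5, 0.2, 0.4]]])  # (B, num_queries, 4) normalized cxcywh
    image_sizes = torch.tensor([[480, 640]])                 # (B, 2) as [H, W]

    boxes = box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
    img_h, img_w = image_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
    boxes = boxes * scale_fct[:, None, :]
    print(boxes)  # ≈ tensor([[[256., 144., 384., 336.]]])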