birder 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/training_cli.py +6 -1
- birder/common/training_utils.py +69 -12
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/deformable_detr.py +12 -12
- birder/net/detection/detr.py +7 -7
- birder/net/detection/lw_detr.py +1181 -0
- birder/net/detection/plain_detr.py +7 -5
- birder/net/detection/retinanet.py +1 -1
- birder/net/detection/rt_detr_v1.py +10 -10
- birder/net/detection/rt_detr_v2.py +47 -64
- birder/net/detection/ssdlite.py +2 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hieradet.py +2 -2
- birder/net/mnasnet.py +2 -2
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +1 -1
- birder/net/rope_flexivit.py +1 -1
- birder/net/rope_vit.py +1 -1
- birder/net/simple_vit.py +1 -1
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/scripts/train.py +12 -8
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +2 -1
- birder/scripts/train_kd.py +12 -8
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/METADATA +3 -3
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/RECORD +40 -39
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/WHEEL +1 -1
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/top_level.txt +0 -0
birder/common/training_cli.py
CHANGED
@@ -56,7 +56,9 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
 )
 
 
-def add_lr_wd_args(
+def add_lr_wd_args(
+    parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False, backbone_layer_decay: bool = False
+) -> None:
     group = parser.add_argument_group("Learning rate and regularization parameters")
     group.add_argument("--lr", type=float, default=0.1, metavar="LR", help="base learning rate")
     group.add_argument("--bias-lr", type=float, metavar="LR", help="learning rate of biases")
@@ -92,6 +94,9 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
         help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
     )
     group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
+    if backbone_layer_decay is True:
+        group.add_argument("--backbone-layer-decay", type=float, help="backbone layer-wise learning rate decay (LLRD)")
+
     group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
     group.add_argument(
         "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
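The new `backbone_layer_decay` switch only adds the `--backbone-layer-decay` argument when a training entry point asks for it. A minimal sketch of how a script could opt in, assuming no other argument in the group is required; the parser wiring outside `add_lr_wd_args` and the example values are illustrative, not birder code:

```python
import argparse

from birder.common import training_cli

# Hypothetical entry point that wants per-backbone LLRD
parser = argparse.ArgumentParser(description="hypothetical detection training script")
training_cli.add_lr_wd_args(parser, backbone_lr=True, backbone_layer_decay=True)

args = parser.parse_args(["--lr", "1e-4", "--backbone-layer-decay", "0.8"])
print(args.backbone_layer_decay)  # 0.8
```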
birder/common/training_utils.py
CHANGED
@@ -343,7 +343,7 @@ def count_layers(model: torch.nn.Module) -> int:
     return num_layers
 
 
-# pylint: disable=protected-access,too-many-locals,too-many-branches
+# pylint: disable=protected-access,too-many-locals,too-many-branches,too-many-statements
 def optimizer_parameter_groups(
     model: torch.nn.Module,
     weight_decay: float,
@@ -352,6 +352,7 @@ optimizer_parameter_groups(
     custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
     custom_layer_weight_decay: Optional[dict[str, float]] = None,
     layer_decay: Optional[float] = None,
+    backbone_layer_decay: Optional[float] = None,
     layer_decay_min_scale: Optional[float] = None,
     layer_decay_no_opt_scale: Optional[float] = None,
     bias_lr: Optional[float] = None,
@@ -388,6 +389,8 @@ optimizer_parameter_groups(
         Applied to parameters whose names contain the specified keys.
     layer_decay
         Layer-wise learning rate decay factor.
+    backbone_layer_decay
+        Layer-wise learning rate decay factor for backbone parameters only.
     layer_decay_min_scale
         Minimum learning rate scale factor when using layer decay. Prevents layers from having too small learning rates.
     layer_decay_no_opt_scale
@@ -434,6 +437,27 @@ optimizer_parameter_groups(
         if layer_decay is not None:
             logger.warning("Assigning lr scaling (layer decay) without a block group map")
 
+    backbone_group_map: dict[str, int] = {}
+    backbone_num_layers = 0
+    if backbone_layer_decay is not None:
+        backbone_module = getattr(model, "backbone", None)
+        if backbone_module is None:
+            logger.warning("Backbone layer decay requested but model has no backbone")
+            backbone_layer_decay = None
+        else:
+            backbone_block_group_regex = getattr(backbone_module, "block_group_regex", None)
+            if backbone_block_group_regex is not None:
+                names = [n for n, _ in backbone_module.named_parameters()]
+                groups = group_by_regex(names, backbone_block_group_regex)
+                backbone_group_map = {
+                    f"backbone.{item}": index for index, sublist in enumerate(groups) for item in sublist
+                }
+                backbone_num_layers = len(groups)
+            else:
+                backbone_group_map = {}
+                backbone_num_layers = count_layers(backbone_module)
+                logger.warning("Assigning lr scaling (backbone layer decay) without a block group map")
+
     # Build layer scale
     if layer_decay_min_scale is None:
         layer_decay_min_scale = 0.0
@@ -444,14 +468,28 @@ optimizer_parameter_groups(
     layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
     logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")
 
+    backbone_layer_scales = []
+    if backbone_layer_decay is not None:
+        backbone_layer_max = backbone_num_layers - 1
+        backbone_layer_scales = [
+            max(layer_decay_min_scale, backbone_layer_decay ** (backbone_layer_max - i))
+            for i in range(backbone_num_layers)
+        ]
+        logger.info(
+            "Backbone layer scaling ranges from "
+            f"{min(backbone_layer_scales)} to {max(backbone_layer_scales)} across {backbone_num_layers} layers"
+        )
+
     # Set weight decay and layer decay
     idx = 0
+    backbone_idx = 0
     params = []
     module_stack_with_prefix = [(model, "")]
     visited_modules = []
     while len(module_stack_with_prefix) > 0:  # pylint: disable=too-many-nested-blocks
         skip_module = False
         module, prefix = module_stack_with_prefix.pop()
+        is_backbone_module = prefix == "backbone" or prefix.startswith("backbone.")
         if id(module) in visited_modules:
             skip_module = True
 
@@ -460,23 +498,35 @@ optimizer_parameter_groups(
         for name, p in module.named_parameters(recurse=False):
             target_name = f"{prefix}.{name}" if prefix != "" else name
             idx = group_map.get(target_name, idx)
+            is_backbone_param = target_name.startswith("backbone.")
+            if backbone_layer_decay is not None and is_backbone_param is True:
+                backbone_idx = backbone_group_map.get(target_name, backbone_idx)
             if skip_module is True:
                 break
 
             parameters_found = True
             if p.requires_grad is False:
                 continue
-            if
-            if
-
+            if layer_decay_no_opt_scale is not None:
+                if backbone_layer_decay is not None and is_backbone_param is True:
+                    if backbone_layer_scales and backbone_layer_scales[backbone_idx] < layer_decay_no_opt_scale:
+                        p.requires_grad_(False)
+                elif layer_decay is not None:
+                    if layer_scales[idx] < layer_decay_no_opt_scale:
+                        p.requires_grad_(False)
 
             is_custom_key = False
             if custom_keys_weight_decay is not None:
                 for key, custom_wd in custom_keys_weight_decay:
                     target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
                     if key == target_name_for_custom_key:
-                        # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
-
+                        # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
+                        if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
+                            lr_scale = layer_scales[idx]
+                        elif backbone_layer_decay is not None and is_backbone_param is True:
+                            lr_scale = backbone_layer_scales[backbone_idx]
+                        else:
+                            lr_scale = 1.0
                         if custom_layer_lr_scale is not None:
                             for layer_name_key, custom_scale in custom_layer_lr_scale.items():
                                 if layer_name_key in target_name:
@@ -500,8 +550,8 @@ optimizer_parameter_groups(
                         # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
                         if bias_lr is not None and target_name.endswith(".bias") is True:
                             d["lr"] = bias_lr
-                        elif backbone_lr is not None and
-                            d["lr"] = backbone_lr
+                        elif backbone_lr is not None and is_backbone_param is True:
+                            d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
                         elif lr_scale != 1.0:
                             d["lr"] = base_lr * lr_scale
 
@@ -522,8 +572,13 @@ optimizer_parameter_groups(
                         wd = custom_wd_value
                         break
 
-            # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
-
+            # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
+            if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
+                lr_scale = layer_scales[idx]
+            elif backbone_layer_decay is not None and is_backbone_param is True:
+                lr_scale = backbone_layer_scales[backbone_idx]
+            else:
+                lr_scale = 1.0
            if custom_layer_lr_scale is not None:
                 for layer_name_key, custom_scale in custom_layer_lr_scale.items():
                     if layer_name_key in target_name:
@@ -539,8 +594,8 @@ optimizer_parameter_groups(
             # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
             if bias_lr is not None and target_name.endswith(".bias") is True:
                 d["lr"] = bias_lr
-            elif backbone_lr is not None and
-                d["lr"] = backbone_lr
+            elif backbone_lr is not None and is_backbone_param is True:
+                d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
             elif lr_scale != 1.0:
                 d["lr"] = base_lr * lr_scale
 
@@ -548,6 +603,8 @@ optimizer_parameter_groups(
 
         if parameters_found is True:
             idx += 1
+            if is_backbone_module is True:
+                backbone_idx += 1
 
         for child_name, child_module in reversed(list(module.named_children())):
             child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
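When `backbone_layer_decay` is set, each backbone block group gets the same clamped exponential scale used for the global layer decay. A standalone illustration of that formula (the decay value and layer count below are made up):

```python
def llrd_scales(decay: float, num_layers: int, min_scale: float = 0.0) -> list[float]:
    # scale_i = max(min_scale, decay ** (num_layers - 1 - i)); the last layer keeps scale 1.0
    layer_max = num_layers - 1
    return [max(min_scale, decay ** (layer_max - i)) for i in range(num_layers)]


print(llrd_scales(0.8, 4))  # approximately [0.512, 0.64, 0.8, 1.0] - earlier layers get smaller scales
```

Per the branches above, when `--backbone-lr` is also given, the backbone scale multiplies the backbone learning rate (`d["lr"] = backbone_lr * lr_scale`); otherwise the scale is applied to the base learning rate.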
birder/net/_vit_configs.py
CHANGED
@@ -111,6 +111,11 @@ def register_vit_configs(vit: type[BaseNet]) -> None:
         vit,
         config={"patch_size": 16, **BASE, "layer_scale_init_value": 1e-5},
     )
+    registry.register_model_config(
+        "vit_b16_pn",
+        vit,
+        config={"patch_size": 16, **BASE, "pre_norm": True, "norm_layer_eps": 1e-5},
+    )
     registry.register_model_config(
         "vit_b16_qkn_ls",
         vit,
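The new `vit_b16_pn` entry turns on `pre_norm=True` with `norm_layer_eps=1e-5`. For readers unfamiliar with the flag, the sketch below shows the generic pre-norm block layout (LayerNorm before attention/MLP rather than after the residual); it is an illustration only, not birder's ViT implementation, and the dimensions are arbitrary:

```python
import torch
from torch import nn


class PreNormBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, eps=eps)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize first, then attend; the residual stays un-normalized ("pre-norm")
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x


block = PreNormBlock(dim=768, num_heads=12)
print(block(torch.randn(1, 197, 768)).shape)  # torch.Size([1, 197, 768])
```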
birder/net/cait.py
CHANGED
@@ -231,11 +231,11 @@ class CaiT(BaseNet):
             if isinstance(m, nn.Linear):
                 nn.init.trunc_normal_(m.weight, std=0.02)
                 if m.bias is not None:
-                    nn.init.
+                    nn.init.zeros_(m.bias)
 
             elif isinstance(m, nn.LayerNorm):
-                nn.init.
-                nn.init.
+                nn.init.zeros_(m.bias)
+                nn.init.ones_(m.weight)
 
         nn.init.trunc_normal_(self.pos_embed, std=0.02)
         nn.init.trunc_normal_(self.cls_token, std=0.02)
birder/net/coat.py
CHANGED
@@ -474,11 +474,11 @@ class CoaT(DetectorBackbone):
             if isinstance(m, nn.Linear):
                 nn.init.trunc_normal_(m.weight, std=0.02)
                 if m.bias is not None:
-                    nn.init.
+                    nn.init.zeros_(m.bias)
 
             elif isinstance(m, nn.LayerNorm):
-                nn.init.
-                nn.init.
+                nn.init.zeros_(m.bias)
+                nn.init.ones_(m.weight)
 
         nn.init.trunc_normal_(self.cls_token1, std=0.02)
         nn.init.trunc_normal_(self.cls_token2, std=0.02)
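The cait.py and coat.py hunks both touch the standard ViT-style weight initialization: Linear and LayerNorm biases are zeroed and LayerNorm weights set to one. A standalone sketch of that pattern (a generic module loop, not a copy of the birder code):

```python
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)

print(float(model[1].weight.sum()))  # 8.0 - LayerNorm weights are all ones
```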
birder/net/deit.py
CHANGED
@@ -167,7 +167,7 @@ class DeiT(DetectorBackbone):
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/deit3.py
CHANGED
@@ -185,7 +185,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
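Both DeiT backbones make the same one-line change: `strict=True` turns a silent length mismatch between `return_stages` and the features returned by the encoder into an error (available since Python 3.10). A quick standalone demonstration with made-up lists:

```python
stages = ["stage1", "stage2", "stage3"]
features = ["f1", "f2"]

# Default zip() silently truncates to the shorter iterable
print(list(zip(stages, features)))  # [('stage1', 'f1'), ('stage2', 'f2')]

try:
    list(zip(stages, features, strict=True))
except ValueError as e:
    print(e)  # zip() argument 2 is shorter than argument 1
```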
birder/net/detection/__init__.py
CHANGED
@@ -3,6 +3,7 @@ from birder.net.detection.detr import DETR
 from birder.net.detection.efficientdet import EfficientDet
 from birder.net.detection.faster_rcnn import Faster_RCNN
 from birder.net.detection.fcos import FCOS
+from birder.net.detection.lw_detr import LW_DETR
 from birder.net.detection.plain_detr import Plain_DETR
 from birder.net.detection.retinanet import RetinaNet
 from birder.net.detection.rt_detr_v1 import RT_DETR_v1
@@ -21,6 +22,7 @@ __all__ = [
     "EfficientDet",
     "Faster_RCNN",
     "FCOS",
+    "LW_DETR",
     "Plain_DETR",
     "RetinaNet",
     "RT_DETR_v1",
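A quick way to confirm the new detector is exported (this only checks the public API added above; it does not construct a model, and the printed value is whatever `lw_detr.py` defines):

```python
import birder.net.detection as detection

print("LW_DETR" in detection.__all__)  # True
print(detection.LW_DETR)
```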
birder/net/detection/deformable_detr.py
CHANGED
@@ -56,7 +56,7 @@ class HungarianMatcher(nn.Module):
     @torch.jit.unused  # type: ignore[untyped-decorator]
     def forward(
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
-    ) -> list[torch.Tensor]:
+    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
         with torch.no_grad():
             B, num_queries = class_logits.shape[:2]
 
@@ -135,7 +135,7 @@ class MultiScaleDeformableAttention(nn.Module):
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
-        nn.init.
+        nn.init.zeros_(self.sampling_offsets.weight)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = (
@@ -149,12 +149,12 @@ class MultiScaleDeformableAttention(nn.Module):
         with torch.no_grad():
             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
 
-        nn.init.
-        nn.init.
+        nn.init.zeros_(self.attention_weights.weight)
+        nn.init.zeros_(self.attention_weights.bias)
         nn.init.xavier_uniform_(self.value_proj.weight)
-        nn.init.
+        nn.init.zeros_(self.value_proj.bias)
         nn.init.xavier_uniform_(self.output_proj.weight)
-        nn.init.
+        nn.init.zeros_(self.output_proj.bias)
 
     def forward(
         self,
@@ -279,11 +279,10 @@ class DeformableTransformerDecoderLayer(nn.Module):
         self_attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self attention
-
-        k = tgt + query_pos
+        q_k = tgt + query_pos
 
         tgt2, _ = self.self_attn(
-
+            q_k.transpose(0, 1), q_k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
         )
         tgt2 = tgt2.transpose(0, 1)
         tgt = tgt + self.dropout(tgt2)
@@ -587,7 +586,7 @@ class Deformable_DETR(DetectionBaseNet):
 
         self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
-        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
 
         class_embed = nn.Linear(hidden_dim, self.num_classes)
         bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
@@ -641,7 +640,8 @@ class Deformable_DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)
 
-
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -709,7 +709,7 @@ class Deformable_DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []
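The decoder self-attention edit above follows the usual DETR convention: positional embeddings are added to queries and keys, while values stay the raw target tokens (the same change appears in detr.py below). A generic, self-contained sketch of that pattern; shapes and `batch_first=True` are illustrative, whereas the birder modules run attention in sequence-first layout, hence the `transpose(0, 1)` calls in the diff:

```python
import torch
from torch import nn

self_attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
tgt = torch.randn(2, 100, 256)        # decoder object queries
query_pos = torch.randn(2, 100, 256)  # learned query positional embedding

# Queries and keys carry the positional embedding, values do not
q_k = tgt + query_pos
tgt2, _ = self_attn(q_k, q_k, tgt, need_weights=False)
print(tgt2.shape)  # torch.Size([2, 100, 256])
```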
birder/net/detection/detr.py
CHANGED
@@ -49,7 +49,7 @@ class HungarianMatcher(nn.Module):
     @torch.jit.unused  # type: ignore[untyped-decorator]
     def forward(
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
-    ) -> list[torch.Tensor]:
+    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
         with torch.no_grad():
             B, num_queries = class_logits.shape[:2]
 
@@ -148,10 +148,9 @@ class TransformerDecoderLayer(nn.Module):
         query_pos: torch.Tensor,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        k = tgt + query_pos
+        q_k = tgt + query_pos
 
-        tgt2, _ = self.self_attn(
+        tgt2, _ = self.self_attn(q_k, q_k, value=tgt, need_weights=False)
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
         tgt2, _ = self.multihead_attn(
@@ -341,7 +340,7 @@ class DETR(DetectionBaseNet):
         )
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
 
-        self.matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=1.0, cost_bbox=5.0, cost_giou=2.0)
         empty_weight = torch.ones(self.num_classes)
         empty_weight[0] = 0.1
         self.empty_weight = nn.Buffer(empty_weight)
@@ -365,7 +364,8 @@ class DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)
 
-
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -422,7 +422,7 @@ class DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []
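The matcher's return annotation now reflects what downstream code relies on: one `(src_idx, tgt_idx)` pair per image. A standalone sketch of how `_get_src_permutation_idx` flattens those pairs into indices usable for batched gathering, mirroring the body shown in the diff (the index values are invented for the example):

```python
import torch

# Matcher output: per image, (matched query indices, matched target indices)
indices = [
    (torch.tensor([3, 17]), torch.tensor([0, 1])),  # image 0: queries 3 and 17 matched
    (torch.tensor([5]), torch.tensor([0])),          # image 1: query 5 matched
]

batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.concat([src for (src, _) in indices])
print(batch_idx.tolist(), src_idx.tolist())  # [0, 0, 1] [3, 17, 5]
```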