birder 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +18 -0
- birder/common/training_utils.py +123 -10
- birder/data/collators/detection.py +10 -3
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/net/__init__.py +8 -0
- birder/net/detection/efficientdet.py +65 -86
- birder/net/detection/rt_detr_v1.py +1 -0
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +39 -40
- birder/net/detection/yolo_v4.py +28 -26
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/fasternet.py +1 -1
- birder/net/gc_vit.py +671 -0
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +342 -0
- birder/net/lit_v2.py +436 -0
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/resnet_v1.py +1 -1
- birder/net/resnext.py +67 -25
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +2 -2
- birder/net/vit.py +0 -15
- birder/net/vovnet_v2.py +31 -1
- birder/scripts/benchmark.py +90 -21
- birder/scripts/predict.py +1 -0
- birder/scripts/predict_detection.py +18 -11
- birder/scripts/train.py +10 -34
- birder/scripts/train_barlow_twins.py +10 -34
- birder/scripts/train_byol.py +10 -34
- birder/scripts/train_capi.py +10 -35
- birder/scripts/train_data2vec.py +9 -34
- birder/scripts/train_data2vec2.py +9 -34
- birder/scripts/train_detection.py +48 -40
- birder/scripts/train_dino_v1.py +10 -34
- birder/scripts/train_dino_v2.py +9 -34
- birder/scripts/train_dino_v2_dist.py +9 -34
- birder/scripts/train_franca.py +9 -34
- birder/scripts/train_i_jepa.py +9 -34
- birder/scripts/train_ibot.py +9 -34
- birder/scripts/train_kd.py +156 -64
- birder/scripts/train_mim.py +10 -34
- birder/scripts/train_mmcr.py +10 -34
- birder/scripts/train_rotnet.py +10 -34
- birder/scripts/train_simclr.py +10 -34
- birder/scripts/train_vicreg.py +10 -34
- birder/tools/auto_anchors.py +20 -1
- birder/tools/pack.py +172 -103
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/common/lib.py
CHANGED

@@ -1,11 +1,7 @@
 import os
-import random
 from typing import Any
 from typing import Optional

-import numpy as np
-import torch
-
 from birder.conf import settings
 from birder.data.transforms.classification import RGBType
 from birder.model_registry import registry
@@ -19,11 +15,8 @@ from birder.net.ssl.base import SSLBaseNet
 from birder.version import __version__


-def set_random_seeds(seed: int) -> None:
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    np.random.seed(seed)
-    random.seed(seed)
+def env_bool(name: str) -> bool:
+    return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}


 def get_size_from_signature(signature: SignatureType | DetectionSignatureType) -> tuple[int, int]:
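The removed seeding helper moves to `birder.common.training_utils` (see that file's diff below), and `lib.py` gains a small environment-flag reader instead. A minimal sketch of the new helper's behavior; the `MY_FLAG` variable name is made up for illustration:

```python
import os


def env_bool(name: str) -> bool:
    # Accepts "1", "true", "yes", "on" in any case; anything else (or unset) is False
    return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}


os.environ["MY_FLAG"] = "Yes"
assert env_bool("MY_FLAG") is True
assert env_bool("MY_UNSET_FLAG") is False
```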
birder/common/training_cli.py
CHANGED

@@ -5,6 +5,7 @@ import typing
 from typing import Optional
 from typing import get_args

+from birder.common.cli import FlexibleDictAction
 from birder.common.cli import ValidationError
 from birder.common.training_utils import OptimizerType
 from birder.common.training_utils import SchedulerType
@@ -82,11 +83,23 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
         metavar="WD",
         help="weight decay for embedding parameters for vision transformer models",
     )
+    group.add_argument(
+        "--custom-layer-wd",
+        action=FlexibleDictAction,
+        metavar="LAYER=WD",
+        help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
+    )
     group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
     group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
     group.add_argument(
         "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
     )
+    group.add_argument(
+        "--custom-layer-lr-scale",
+        action=FlexibleDictAction,
+        metavar="LAYER=SCALE",
+        help="custom lr_scale for specific layers by name (e.g., offset_conv=0.01,attention=0.5)",
+    )


 def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
@@ -185,6 +198,11 @@ def add_detection_input_args(parser: argparse.ArgumentParser) -> None:
         action="store_true",
         help="enable random square resize once per batch (capped by max(--size))",
     )
+    group.add_argument(
+        "--multiscale-min-size",
+        type=int,
+        help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
+    )


 def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
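Both new options parse `LAYER=VALUE` pairs into a dict via `FlexibleDictAction` (defined in `birder.common.cli`, whose body is not shown in this diff). A hedged sketch of the parsing behavior using a simplified stand-in action, not the real implementation:

```python
import argparse
from typing import Any, Optional


class DictAction(argparse.Action):
    """Simplified stand-in for FlexibleDictAction (illustration only)."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Any,
        option_string: Optional[str] = None,
    ) -> None:
        # "offset_conv=0.01,attention=0.5" -> {"offset_conv": 0.01, "attention": 0.5}
        result = {}
        for pair in str(values).split(","):
            (key, _, value) = pair.partition("=")
            result[key] = float(value)
        setattr(namespace, self.dest, result)


parser = argparse.ArgumentParser()
parser.add_argument("--custom-layer-lr-scale", action=DictAction, metavar="LAYER=SCALE")
args = parser.parse_args(["--custom-layer-lr-scale", "offset_conv=0.01,attention=0.5"])
assert args.custom_layer_lr_scale == {"offset_conv": 0.01, "attention": 0.5}
```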
birder/common/training_utils.py
CHANGED

@@ -3,8 +3,10 @@ import contextlib
 import logging
 import math
 import os
+import random
 import re
 import subprocess
+import sys
 from collections import deque
 from collections.abc import Callable
 from collections.abc import Generator
@@ -29,12 +31,25 @@ from birder.data.transforms.classification import training_preset
 from birder.optim import Lamb
 from birder.optim import Lars
 from birder.scheduler import CooldownLR
+from birder.version import __version__ as birder_version

 logger = logging.getLogger(__name__)

 OptimizerType = Literal["sgd", "rmsprop", "adam", "adamw", "nadam", "nadamw", "lamb", "lambw", "lars"]
 SchedulerType = Literal["constant", "step", "multistep", "cosine", "polynomial"]

+###############################################################################
+# Core Utilities
+###############################################################################
+
+
+def set_random_seeds(seed: int) -> None:
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
 ###############################################################################
 # Data Sampling
 ###############################################################################
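`set_random_seeds` now lives here (moved from `birder.common.lib`); one call seeds Python, NumPy, and PyTorch (including all CUDA devices). Assuming birder 0.2.3 is installed:

```python
import random

import numpy as np
import torch

from birder.common.training_utils import set_random_seeds

set_random_seeds(42)
first = (random.random(), np.random.rand(), torch.rand(1))

set_random_seeds(42)  # re-seeding reproduces the same values from all three RNGs
second = (random.random(), np.random.rand(), torch.rand(1))
assert first[0] == second[0] and first[1] == second[1] and torch.equal(first[2], second[2])
```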
@@ -207,13 +222,16 @@ def count_layers(model: torch.nn.Module) -> int:
 def optimizer_parameter_groups(
     model: torch.nn.Module,
     weight_decay: float,
+    base_lr: float,
     norm_weight_decay: Optional[float] = None,
     custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
+    custom_layer_weight_decay: Optional[dict[str, float]] = None,
     layer_decay: Optional[float] = None,
     layer_decay_min_scale: Optional[float] = None,
     layer_decay_no_opt_scale: Optional[float] = None,
     bias_lr: Optional[float] = None,
     backbone_lr: Optional[float] = None,
+    custom_layer_lr_scale: Optional[dict[str, float]] = None,
 ) -> list[dict[str, Any]]:
     """
     Return parameter groups for optimizers with per-parameter group weight decay.
@@ -233,11 +251,16 @@ def optimizer_parameter_groups(
         The PyTorch model whose parameters will be grouped for optimization.
     weight_decay
         Default weight decay (L2 regularization) value applied to parameters.
+    base_lr
+        Base learning rate that will be scaled by lr_scale factors for each parameter group.
     norm_weight_decay
         Weight decay value specifically for normalization layers. If None, uses weight_decay.
     custom_keys_weight_decay
         List of (parameter_name, weight_decay) tuples for applying custom weight decay
         values to specific parameters by name matching.
+    custom_layer_weight_decay
+        Dictionary mapping layer name substrings to custom weight decay values.
+        Applied to parameters whose names contain the specified keys.
     layer_decay
         Layer-wise learning rate decay factor.
     layer_decay_min_scale
@@ -248,6 +271,9 @@ def optimizer_parameter_groups(
         Custom learning rate for bias parameters (parameters ending with '.bias').
     backbone_lr
         Custom learning rate for backbone parameters (parameters starting with 'backbone.').
+    custom_layer_lr_scale
+        Dictionary mapping layer name substrings to custom lr_scale values.
+        Applied to parameters whose names contain the specified keys.

     Returns
     -------
@@ -291,14 +317,14 @@ def optimizer_parameter_groups(
     if layer_decay is not None:
         layer_max = num_layers - 1
         layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
-        logger.info(f"Layer scaling
+        logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")

     # Set weight decay and layer decay
     idx = 0
     params = []
     module_stack_with_prefix = [(model, "")]
     visited_modules = []
-    while len(module_stack_with_prefix) > 0:
+    while len(module_stack_with_prefix) > 0:  # pylint: disable=too-many-nested-blocks
         skip_module = False
         (module, prefix) = module_stack_with_prefix.pop()
         if id(module) in visited_modules:
@@ -324,13 +350,35 @@ def optimizer_parameter_groups(
             for key, custom_wd in custom_keys_weight_decay:
                 target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
                 if key == target_name_for_custom_key:
+                    # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
+                    lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
+                    if custom_layer_lr_scale is not None:
+                        for layer_name_key, custom_scale in custom_layer_lr_scale.items():
+                            if layer_name_key in target_name:
+                                lr_scale = custom_scale
+                                break
+
+                    # Apply custom layer weight decay (substring matching)
+                    wd = custom_wd
+                    if custom_layer_weight_decay is not None:
+                        for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
+                            if layer_name_key in target_name:
+                                wd = custom_wd_value
+                                break
+
                     d = {
                         "params": p,
-                        "weight_decay":
-                        "lr_scale":
+                        "weight_decay": wd,
+                        "lr_scale": lr_scale,  # Used only for reference/debugging
                     }
-                    if backbone_lr is not None and target_name.startswith("backbone.") is True:
+
+                    # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
+                    if bias_lr is not None and target_name.endswith(".bias") is True:
+                        d["lr"] = bias_lr
+                    elif backbone_lr is not None and target_name.startswith("backbone.") is True:
                         d["lr"] = backbone_lr
+                    elif lr_scale != 1.0:
+                        d["lr"] = base_lr * lr_scale

                     params.append(d)
                     is_custom_key = True
@@ -342,16 +390,34 @@ def optimizer_parameter_groups(
             else:
                 wd = weight_decay

+            # Apply custom layer weight decay (substring matching)
+            if custom_layer_weight_decay is not None:
+                for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
+                    if layer_name_key in target_name:
+                        wd = custom_wd_value
+                        break
+
+            # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
+            lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
+            if custom_layer_lr_scale is not None:
+                for layer_name_key, custom_scale in custom_layer_lr_scale.items():
+                    if layer_name_key in target_name:
+                        lr_scale = custom_scale
+                        break
+
             d = {
                 "params": p,
                 "weight_decay": wd,
-                "lr_scale":
+                "lr_scale": lr_scale,  # Used only for reference/debugging
             }
-            if backbone_lr is not None and target_name.startswith("backbone.") is True:
-                d["lr"] = backbone_lr

+            # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
             if bias_lr is not None and target_name.endswith(".bias") is True:
                 d["lr"] = bias_lr
+            elif backbone_lr is not None and target_name.startswith("backbone.") is True:
+                d["lr"] = backbone_lr
+            elif lr_scale != 1.0:
+                d["lr"] = base_lr * lr_scale

             params.append(d)

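The matching added above is plain substring matching on the parameter name, and the learning rate is resolved with the priority bias_lr > backbone_lr > lr_scale. A self-contained sketch of just that per-parameter decision (a distilled reconstruction for illustration, not the library function):

```python
from typing import Optional


def resolve_group(
    name: str,
    base_lr: float,
    weight_decay: float,
    custom_layer_weight_decay: Optional[dict[str, float]] = None,
    custom_layer_lr_scale: Optional[dict[str, float]] = None,
    bias_lr: Optional[float] = None,
    backbone_lr: Optional[float] = None,
) -> dict[str, float]:
    # Substring matching: the first dict key contained in the parameter name wins
    wd = weight_decay
    for key, value in (custom_layer_weight_decay or {}).items():
        if key in name:
            wd = value
            break

    lr_scale = 1.0
    for key, value in (custom_layer_lr_scale or {}).items():
        if key in name:
            lr_scale = value
            break

    group = {"weight_decay": wd, "lr_scale": lr_scale}
    # Learning rate priority: bias_lr > backbone_lr > lr_scale
    if bias_lr is not None and name.endswith(".bias"):
        group["lr"] = bias_lr
    elif backbone_lr is not None and name.startswith("backbone."):
        group["lr"] = backbone_lr
    elif lr_scale != 1.0:
        group["lr"] = base_lr * lr_scale  # lr_scale takes effect through the lr it produces

    return group


group = resolve_group(
    "backbone.stage1.offset_conv.weight",
    base_lr=1e-3,
    weight_decay=0.05,
    custom_layer_weight_decay={"offset_conv": 0.0},
    custom_layer_lr_scale={"offset_conv": 0.01},
)
assert group["weight_decay"] == 0.0 and group["lr_scale"] == 0.01
```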
@@ -442,6 +508,8 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
     else:
         raise ValueError("Unknown optimizer")

+    logger.debug(f"Created {opt} optimizer with lr={lr}, weight_decay={args.wd}")
+
     return optimizer


@@ -477,10 +545,10 @@ def get_scheduler(

     main_steps = steps - begin_step - remaining_warmup - remaining_cooldown - 1

-    logger.debug(f"
+    logger.debug(f"Scheduler using {steps_per_epoch} steps per epoch")
     logger.debug(
         f"Scheduler {args.lr_scheduler} set for {steps} steps of which {warmup_steps} "
-        f"are warmup and {cooldown_steps} cooldown"
+        f"are warmup and {cooldown_steps} are cooldown"
     )
     logger.debug(
         f"Currently starting from step {begin_step} with {remaining_warmup} remaining warmup steps "
@@ -810,6 +878,51 @@ def is_local_primary(args: argparse.Namespace) -> bool:
     return args.local_rank == 0  # type: ignore[no-any-return]


+def init_training(
+    args: argparse.Namespace,
+    log: logging.Logger,
+    *,
+    cudnn_dynamic_size: bool = False,
+) -> tuple[torch.device, int, bool]:
+    init_distributed_mode(args)
+
+    log.info(f"Starting training, birder version: {birder_version}, pytorch version: {torch.__version__}")
+
+    log_git_info()
+
+    if args.cpu is True:
+        device = torch.device("cpu")
+        device_id = 0
+    else:
+        device = torch.device("cuda")
+        device_id = torch.cuda.current_device()
+
+    if args.use_deterministic_algorithms is True:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    elif cudnn_dynamic_size is True:
+        # Dynamic sizes: avoid per-size algorithm selection overhead.
+        torch.backends.cudnn.enabled = False
+    else:
+        torch.backends.cudnn.enabled = True
+        torch.backends.cudnn.benchmark = True
+
+    if args.seed is not None:
+        set_random_seeds(args.seed)
+
+    if args.non_interactive is True or is_local_primary(args) is False:
+        disable_tqdm = True
+    elif sys.stderr.isatty() is False:
+        disable_tqdm = True
+    else:
+        disable_tqdm = False
+
+    # Enable or disable the autograd anomaly detection.
+    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
+
+    return (device, device_id, disable_tqdm)
+
+
 ###############################################################################
 # Utility Functions
 ###############################################################################
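`init_training` consolidates setup boilerplate the per-task scripts used to repeat, which is why every `train_*.py` in the file list above loses roughly 34 lines. Its cudnn policy is the non-obvious part; a standalone sketch of that three-way decision (mirrors the branch above, not the birder function itself):

```python
import torch


def configure_cudnn(deterministic: bool, dynamic_input_sizes: bool) -> None:
    if deterministic:
        # Reproducibility: disable auto-tuned (potentially non-deterministic) kernels
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    elif dynamic_input_sizes:
        # Ever-changing input shapes would re-trigger cudnn algorithm search per shape
        torch.backends.cudnn.enabled = False
    else:
        # Fixed shapes: let cudnn benchmark once and reuse the fastest kernels
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True


configure_cudnn(deterministic=False, dynamic_input_sizes=True)
```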
birder/data/collators/detection.py
CHANGED

@@ -1,13 +1,14 @@
 import math
 import random
 from typing import Any
+from typing import Optional

 import torch
 from torchvision import tv_tensors
 from torchvision.transforms import v2
 from torchvision.transforms.v2 import functional as F

-
+from birder.data.transforms.detection import build_multiscale_sizes


 def collate_fn(batch: list[tuple[Any, ...]]) -> tuple[Any, ...]:
@@ -63,13 +64,19 @@ class DetectionCollator:


 class BatchRandomResizeCollator(DetectionCollator):
-    def __init__(
+    def __init__(
+        self,
+        input_offset: int,
+        size: tuple[int, int],
+        size_divisible: int = 32,
+        multiscale_min_size: Optional[int] = None,
+    ) -> None:
         super().__init__(input_offset, size_divisible=size_divisible)
         if size is None:
             raise ValueError("size must be provided for batch multiscale")

         max_side = max(size)
-        sizes = [side for side in
+        sizes = [side for side in build_multiscale_sizes(multiscale_min_size) if side <= max_side]
         if len(sizes) == 0:
             sizes = [max_side]

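With the new parameter, the per-batch candidate sides come from `build_multiscale_sizes` (defined in the transforms diff below) and are then capped at the batch's maximum side. For the default 480-800 range and `size=(640, 640)`:

```python
# Multiples of 32 from 480 to 800, filtered to the configured maximum side
sizes = [side for side in range(480, 801, 32) if side <= 640]
print(sizes)  # [480, 512, 544, 576, 608, 640]
```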
birder/data/datasets/coco.py
CHANGED

@@ -98,10 +98,14 @@ class CocoTraining(CocoBase):
 class CocoInference(CocoBase):
     def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any, list[int]]:
         coco_id = self.dataset.ids[index]
-
+        img_info = self.dataset.coco.loadImgs(coco_id)[0]
+        path = img_info["file_name"]
         (sample, labels) = self.dataset[index]

-
+        # Get original image size (height, width) before transforms
+        orig_size = [img_info["height"], img_info["width"]]
+
+        return (path, sample, labels, orig_size)


 class CocoMosaicTraining(CocoBase):
@@ -127,9 +131,7 @@ class CocoMosaicTraining(CocoBase):
         self._mosaic_decay_epochs: Optional[int] = None
         self._mosaic_decay_start: Optional[int] = None

-    def configure_mosaic_linear_decay(
-        self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1
-    ) -> None:
+    def configure_mosaic_linear_decay(self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1) -> None:
         if total_epochs <= 0:
             raise ValueError("total_epochs must be positive")
         if decay_fraction <= 0.0 or decay_fraction > 1.0:
@@ -141,11 +143,7 @@ class CocoMosaicTraining(CocoBase):
         self._mosaic_decay_start = max(1, total_epochs - decay_epochs + 1)

     def update_mosaic_prob(self, epoch: int) -> Optional[float]:
-        if (
-            self._mosaic_base_prob is None
-            or self._mosaic_decay_epochs is None
-            or self._mosaic_decay_start is None
-        ):
+        if self._mosaic_base_prob is None or self._mosaic_decay_epochs is None or self._mosaic_decay_start is None:
            return None

         if epoch >= self._mosaic_decay_start:
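The decay window sits in the final `decay_fraction` of training; before `_mosaic_decay_start` the base mosaic probability is used unchanged. A sketch of such a linear tail schedule, assuming `decay_epochs` is the rounded fraction of total epochs (the exact rounding and interpolation are outside the shown hunks):

```python
def mosaic_prob(epoch: int, base_prob: float, total_epochs: int, decay_fraction: float = 0.1) -> float:
    decay_epochs = max(1, round(total_epochs * decay_fraction))
    decay_start = max(1, total_epochs - decay_epochs + 1)  # matches _mosaic_decay_start above
    if epoch < decay_start:
        return base_prob

    progress = (epoch - decay_start + 1) / decay_epochs
    return base_prob * max(0.0, 1.0 - progress)


print(mosaic_prob(80, 1.0, 100))  # 1.0 - before the decay window (which starts at epoch 91)
print(mosaic_prob(95, 1.0, 100))  # 0.5 - halfway through the window
```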
birder/data/transforms/detection.py
CHANGED

@@ -1,3 +1,4 @@
+import math
 import random
 from collections.abc import Callable
 from typing import Any
@@ -10,6 +11,24 @@ from torchvision.transforms import v2

 from birder.data.transforms.classification import RGBType

+MULTISCALE_STEP = 32
+DEFAULT_MULTISCALE_MIN_SIZE = 480
+DEFAULT_MULTISCALE_MAX_SIZE = 800
+
+
+def build_multiscale_sizes(
+    min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE
+) -> tuple[int, ...]:
+    if min_size is None:
+        min_size = DEFAULT_MULTISCALE_MIN_SIZE
+
+    start = int(math.ceil(min_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+    end = int(math.floor(max_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+    if end < start:
+        return (start,)
+
+    return tuple(range(start, end + 1, MULTISCALE_STEP))
+

 class ResizeWithRandomInterpolation(nn.Module):
     def __init__(
@@ -39,6 +58,7 @@ def get_birder_augment(
     dynamic_size: bool,
     multiscale: bool,
     max_size: Optional[int],
+    multiscale_min_size: Optional[int],
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     if dynamic_size is True:
@@ -78,9 +98,7 @@ def get_birder_augment(
     # Resize
     if multiscale is True:
         transformations.append(
-            v2.RandomShortestSize(
-                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-            ),
+            v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
         )
     else:
         transformations.append(
@@ -132,6 +150,7 @@ get_birder_augment(
 AugType = Literal["birder", "lsj", "multiscale", "ssd", "ssdlite", "yolo", "detr"]


+# pylint: disable=too-many-return-statements
 def training_preset(
     size: tuple[int, int],
     aug_type: AugType,
@@ -140,6 +159,7 @@ def training_preset(
     dynamic_size: bool = False,
     multiscale: bool = False,
     max_size: Optional[int] = None,
+    multiscale_min_size: Optional[int] = None,
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     mean = rgv_values["mean"]
@@ -159,7 +179,9 @@ def training_preset(
         return v2.Compose(  # type:ignore
             [
                 v2.ToImage(),
-                get_birder_augment(
+                get_birder_augment(
+                    size, level, fill_value, dynamic_size, multiscale, max_size, multiscale_min_size, post_mosaic
+                ),
                 v2.ToDtype(torch.float32, scale=True),
                 v2.Normalize(mean=mean, std=std),
                 v2.ToPureTensor(),
@@ -190,9 +212,7 @@ def training_preset(
         return v2.Compose(  # type: ignore
             [
                 v2.ToImage(),
-                v2.RandomShortestSize(
-                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                ),
+                v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
                 v2.RandomHorizontalFlip(0.5),
                 v2.SanitizeBoundingBoxes(),
                 v2.ToDtype(torch.float32, scale=True),
@@ -264,21 +284,18 @@ def training_preset(
         )

     if aug_type == "detr":
+        multiscale_sizes = build_multiscale_sizes(multiscale_min_size)
         return v2.Compose(  # type: ignore
             [
                 v2.ToImage(),
                 v2.RandomChoice(
                     [
-                        v2.RandomShortestSize(
-                            (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                        ),
+                        v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                         v2.Compose(
                             [
                                 v2.RandomShortestSize((400, 500, 600)),
                                 v2.RandomIoUCrop() if post_mosaic is False else v2.Identity(),  # RandomSizeCrop
-                                v2.RandomShortestSize(
-                                    (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                                ),
+                                v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                             ]
                         ),
                     ]
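With the defaults, `build_multiscale_sizes` reproduces exactly the tuple that was previously hardcoded in several places, and a custom minimum is rounded up to the next multiple of 32. Assuming birder 0.2.3:

```python
from birder.data.transforms.detection import build_multiscale_sizes

print(build_multiscale_sizes())     # (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
print(build_multiscale_sizes(500))  # (512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
print(build_multiscale_sizes(832))  # (832,) - a minimum above the maximum yields a single size
```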
birder/inference/detection.py
CHANGED

@@ -5,17 +5,99 @@ from typing import Optional
 import torch
 import torch.amp
 from PIL import Image
+from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from tqdm import tqdm

 from birder.conf import settings
+from birder.data.collators.detection import batch_images
 from birder.data.transforms.detection import InferenceTransform
+from birder.inference.wbf import fuse_detections_wbf
+from birder.net.base import make_divisible
+
+
+def _normalize_image_sizes(inputs: torch.Tensor, image_sizes: Optional[list[list[int]]]) -> list[list[int]]:
+    if image_sizes is not None:
+        return image_sizes
+
+    (_, _, height, width) = inputs.shape
+    return [[height, width] for _ in range(inputs.size(0))]
+
+
+def _hflip_inputs(inputs: torch.Tensor, image_sizes: list[list[int]]) -> torch.Tensor:
+    # Detection collator pads on the right/bottom, so flip only the valid region to keep padding aligned.
+    flipped = inputs.clone()
+    for idx, (height, width) in enumerate(image_sizes):
+        flipped[idx, :, :height, :width] = torch.flip(inputs[idx, :, :height, :width], dims=[2])
+
+    return flipped
+
+
+def _resize_batch(
+    inputs: torch.Tensor, image_sizes: list[list[int]], scale: float, size_divisible: int
+) -> tuple[torch.Tensor, torch.Tensor, list[list[int]]]:
+    resized_images: list[torch.Tensor] = []
+    for idx, (height, width) in enumerate(image_sizes):
+        target_h = make_divisible(height * scale, size_divisible)
+        target_w = make_divisible(width * scale, size_divisible)
+        image = inputs[idx, :, :height, :width]
+        resized = F.interpolate(image.unsqueeze(0), size=(target_h, target_w), mode="bilinear", align_corners=False)
+        resized_images.append(resized.squeeze(0))
+
+    return batch_images(resized_images, size_divisible)
+
+
+def _rescale_boxes(boxes: torch.Tensor, from_size: list[int], to_size: list[int]) -> torch.Tensor:
+    scale_w = to_size[1] / from_size[1]
+    scale_h = to_size[0] / from_size[0]
+    scale = boxes.new_tensor([scale_w, scale_h, scale_w, scale_h])
+    return boxes * scale
+
+
+def _rescale_detections(
+    detections: list[dict[str, torch.Tensor]],
+    from_sizes: list[list[int]],
+    to_sizes: list[list[int]],
+) -> list[dict[str, torch.Tensor]]:
+    for idx, (detection, from_size, to_size) in enumerate(zip(detections, from_sizes, to_sizes)):
+        boxes = detection["boxes"]
+        if boxes.numel() == 0:
+            continue
+
+        detections[idx]["boxes"] = _rescale_boxes(boxes, from_size, to_size)
+
+    return detections
+
+
+def _invert_hflip_boxes(boxes: torch.Tensor, image_size: list[int]) -> torch.Tensor:
+    width = boxes.new_tensor(image_size[1])
+    x1 = boxes[:, 0]
+    x2 = boxes[:, 2]
+    flipped = boxes.clone()
+    flipped[:, 0] = width - x2
+    flipped[:, 2] = width - x1
+
+    return flipped
+
+
+def _invert_detections(
+    detections: list[dict[str, torch.Tensor]], image_sizes: list[list[int]]
+) -> list[dict[str, torch.Tensor]]:
+    for idx, (detection, image_size) in enumerate(zip(detections, image_sizes)):
+        boxes = detection["boxes"]
+        if boxes.numel() == 0:
+            continue
+
+        detections[idx]["boxes"] = _invert_hflip_boxes(boxes, image_size)
+
+    return detections


 def infer_image(
     net: torch.nn.Module | torch.ScriptModule,
     sample: Image.Image | str,
     transform: Callable[..., torch.Tensor],
+    tta: bool = False,
     device: Optional[torch.device] = None,
     score_threshold: Optional[float] = None,
     **kwargs: Any,
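The helpers keep boxes in (x1, y1, x2, y2) order, so undoing a horizontal flip maps (x1, x2) to (W - x2, W - x1), which preserves x1 < x2. A quick numeric check of that formula:

```python
import torch

boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])  # (x1, y1, x2, y2) in a 100px-wide image
width = 100.0

flipped = boxes.clone()
flipped[:, 0] = width - boxes[:, 2]  # new x1 = W - old x2
flipped[:, 2] = width - boxes[:, 0]  # new x2 = W - old x1
print(flipped)  # tensor([[50., 20., 90., 80.]])
```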
@@ -43,7 +125,7 @@ def infer_image(
         device = torch.device("cpu")

     input_tensor = transform(image).unsqueeze(dim=0).to(device)
-    detections = infer_batch(net, input_tensor, **kwargs)
+    detections = infer_batch(net, input_tensor, tta=tta, **kwargs)
     if score_threshold is not None:
         for i, detection in enumerate(detections):
             idxs = torch.where(detection["scores"] > score_threshold)
@@ -63,16 +145,36 @@ def infer_batch(
     inputs: torch.Tensor,
     masks: Optional[torch.Tensor] = None,
     image_sizes: Optional[list[list[int]]] = None,
+    tta: bool = False,
     **kwargs: Any,
 ) -> list[dict[str, torch.Tensor]]:
-    (detections, _) = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
-    return detections  # type: ignore[no-any-return]
+    if tta is False:
+        (detections, _) = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
+        return detections  # type: ignore[no-any-return]
+
+    normalized_sizes = _normalize_image_sizes(inputs, image_sizes)
+    detections_list: list[list[dict[str, torch.Tensor]]] = []
+
+    for scale in (0.8, 1.0, 1.2):
+        (scaled_inputs, scaled_masks, scaled_sizes) = _resize_batch(inputs, normalized_sizes, scale, size_divisible=32)
+        (detections, _) = net(scaled_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+        detections = _rescale_detections(detections, scaled_sizes, normalized_sizes)
+        detections_list.append(detections)
+
+        flipped_inputs = _hflip_inputs(scaled_inputs, scaled_sizes)
+        (flipped_detections, _) = net(flipped_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+        flipped_detections = _invert_detections(flipped_detections, scaled_sizes)
+        flipped_detections = _rescale_detections(flipped_detections, scaled_sizes, normalized_sizes)
+        detections_list.append(flipped_detections)
+
+    return fuse_detections_wbf(detections_list, iou_thr=0.55, conf_type="avg")


 def infer_dataloader(
     device: torch.device,
     net: torch.nn.Module | torch.ScriptModule,
     dataloader: DataLoader,
+    tta: bool = False,
     model_dtype: torch.dtype = torch.float32,
     amp: bool = False,
     amp_dtype: Optional[torch.dtype] = None,
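Enabling TTA is a single keyword at the call site, but it runs six forward passes per batch (three scales, each also flipped), so expect roughly six times the inference cost. A usage sketch; `net` and `batch` are assumed to be a loaded birder detection model and a preprocessed NCHW tensor, not defined here:

```python
import torch

from birder.inference.detection import infer_batch

with torch.inference_mode():
    # net and batch are assumed: detection model and normalized image tensor
    detections = infer_batch(net, batch, tta=True)

for detection in detections:
    print(detection["boxes"].shape, detection["scores"].shape)
```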
@@ -97,6 +199,8 @@ def infer_dataloader(
         The model to use for inference.
     dataloader
         The DataLoader containing the dataset to perform inference on.
+    tta
+        Run inference with multi-scale and horizontal flip test time augmentation and fuse results with WBF.
     model_dtype
         The base dtype to use.
     amp
@@ -142,7 +246,7 @@ def infer_dataloader(
             masks = masks.to(device, non_blocking=True)

         with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
-            detections = infer_batch(net, inputs, masks, image_sizes)
+            detections = infer_batch(net, inputs, masks=masks, image_sizes=image_sizes, tta=tta)

         detections = InferenceTransform.postprocess(detections, image_sizes, orig_sizes)
         if targets[0] != settings.NO_LABEL:
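Unlike NMS, weighted boxes fusion merges overlapping candidates instead of discarding all but one, which is why it suits TTA ensembling. A toy single-cluster example of the core idea (the new `birder/inference/wbf.py` module itself is not shown in this diff):

```python
import torch

# Two detections of the same object from different TTA passes
boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0], [14.0, 12.0, 54.0, 52.0]])
scores = torch.tensor([0.9, 0.6])

weights = scores / scores.sum()
fused_box = (boxes * weights.unsqueeze(1)).sum(dim=0)  # confidence-weighted coordinate average
fused_score = scores.mean()  # analogous to conf_type="avg"
print(fused_box)    # tensor([11.6000, 10.8000, 51.6000, 50.8000])
print(fused_score)  # tensor(0.7500)
```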
|