birder 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/__init__.py +13 -0
- birder/adversarial/base.py +101 -0
- birder/adversarial/deepfool.py +173 -0
- birder/adversarial/fgsm.py +51 -18
- birder/adversarial/pgd.py +79 -28
- birder/adversarial/simba.py +172 -0
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +29 -3
- birder/common/training_utils.py +141 -11
- birder/data/collators/detection.py +10 -3
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/data_parallel.py +1 -2
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/introspection/__init__.py +10 -6
- birder/introspection/attention_rollout.py +122 -54
- birder/introspection/base.py +73 -29
- birder/introspection/gradcam.py +71 -100
- birder/introspection/guided_backprop.py +146 -72
- birder/introspection/transformer_attribution.py +182 -0
- birder/net/__init__.py +8 -0
- birder/net/detection/deformable_detr.py +14 -12
- birder/net/detection/detr.py +7 -3
- birder/net/detection/efficientdet.py +65 -86
- birder/net/detection/rt_detr_v1.py +4 -3
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +42 -48
- birder/net/detection/yolo_v4.py +31 -40
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +1 -1
- birder/net/gc_vit.py +671 -0
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +342 -0
- birder/net/lit_v2.py +436 -0
- birder/net/mim/mae_vit.py +7 -8
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/pit.py +1 -1
- birder/net/resnet_v1.py +95 -35
- birder/net/resnext.py +67 -25
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +2 -2
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +4 -2
- birder/net/vit.py +0 -15
- birder/net/vovnet_v2.py +31 -1
- birder/results/gui.py +15 -2
- birder/scripts/benchmark.py +90 -21
- birder/scripts/predict.py +1 -0
- birder/scripts/predict_detection.py +48 -9
- birder/scripts/train.py +33 -50
- birder/scripts/train_barlow_twins.py +19 -40
- birder/scripts/train_byol.py +19 -40
- birder/scripts/train_capi.py +21 -43
- birder/scripts/train_data2vec.py +18 -40
- birder/scripts/train_data2vec2.py +18 -40
- birder/scripts/train_detection.py +89 -57
- birder/scripts/train_dino_v1.py +19 -40
- birder/scripts/train_dino_v2.py +18 -40
- birder/scripts/train_dino_v2_dist.py +25 -40
- birder/scripts/train_franca.py +18 -40
- birder/scripts/train_i_jepa.py +25 -46
- birder/scripts/train_ibot.py +18 -40
- birder/scripts/train_kd.py +179 -81
- birder/scripts/train_mim.py +20 -43
- birder/scripts/train_mmcr.py +19 -40
- birder/scripts/train_rotnet.py +19 -40
- birder/scripts/train_simclr.py +19 -40
- birder/scripts/train_vicreg.py +19 -40
- birder/tools/__main__.py +6 -2
- birder/tools/adversarial.py +147 -96
- birder/tools/auto_anchors.py +380 -0
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +58 -31
- birder/tools/pack.py +172 -103
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/METADATA +4 -3
- {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/RECORD +86 -75
- {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
- {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
- {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.1.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/common/lib.py
CHANGED

@@ -1,11 +1,7 @@
 import os
-import random
 from typing import Any
 from typing import Optional
 
-import numpy as np
-import torch
-
 from birder.conf import settings
 from birder.data.transforms.classification import RGBType
 from birder.model_registry import registry
@@ -19,11 +15,8 @@ from birder.net.ssl.base import SSLBaseNet
 from birder.version import __version__
 
 
-def set_random_seeds(seed: int) -> None:
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    np.random.seed(seed)
-    random.seed(seed)
+def env_bool(name: str) -> bool:
+    return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}
 
 
 def get_size_from_signature(signature: SignatureType | DetectionSignatureType) -> tuple[int, int]:
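The only addition to `lib.py` is the `env_bool` helper; the seeding code it loses moves into `birder/common/training_utils.py` (next file). A minimal sketch of the helper's behavior, using a made-up environment variable name:

```python
import os


def env_bool(name: str) -> bool:
    # Same one-liner as the diff above: only "1", "true", "yes" and "on" count as truthy
    return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}


os.environ["BIRDER_EXAMPLE_FLAG"] = "On"  # hypothetical variable; matching is case-insensitive
assert env_bool("BIRDER_EXAMPLE_FLAG") is True
assert env_bool("BIRDER_UNSET_FLAG") is False  # unset variables read as False
```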
birder/common/training_cli.py
CHANGED

@@ -5,6 +5,7 @@ import typing
 from typing import Optional
 from typing import get_args
 
+from birder.common.cli import FlexibleDictAction
 from birder.common.cli import ValidationError
 from birder.common.training_utils import OptimizerType
 from birder.common.training_utils import SchedulerType
@@ -82,11 +83,23 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
         metavar="WD",
         help="weight decay for embedding parameters for vision transformer models",
     )
+    group.add_argument(
+        "--custom-layer-wd",
+        action=FlexibleDictAction,
+        metavar="LAYER=WD",
+        help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
+    )
     group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
     group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
     group.add_argument(
         "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
     )
+    group.add_argument(
+        "--custom-layer-lr-scale",
+        action=FlexibleDictAction,
+        metavar="LAYER=SCALE",
+        help="custom lr_scale for specific layers by name (e.g., offset_conv=0.01,attention=0.5)",
+    )
 
 
 def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
@@ -110,10 +123,13 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
         type=int,
         default=40,
         metavar="N",
-        help="decrease lr every
+        help="decrease lr every N epochs/steps (relative to after warmup, step scheduler only)",
     )
     group.add_argument(
-        "--lr-steps",
+        "--lr-steps",
+        type=int,
+        nargs="+",
+        help="absolute epoch/step milestones when to decrease lr (multistep scheduler only)",
     )
     group.add_argument(
         "--lr-step-gamma",
@@ -182,6 +198,11 @@ def add_detection_input_args(parser: argparse.ArgumentParser) -> None:
         action="store_true",
         help="enable random square resize once per batch (capped by max(--size))",
     )
+    group.add_argument(
+        "--multiscale-min-size",
+        type=int,
+        help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
+    )
 
 
 def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
@@ -391,7 +412,7 @@ def add_ema_args(
         "--model-ema-warmup",
         type=int,
         metavar="N",
-        help="number of epochs before EMA is applied (defaults to warmup epochs/
+        help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
     )
 
 
@@ -656,6 +677,11 @@ def common_args_validation(args: argparse.Namespace) -> None:
             f"but it is set to '{args.lr_scheduler_update}'"
         )
 
+    # EMA
+    if hasattr(args, "model_ema_steps") is True:
+        if args.model_ema_steps < 1:
+            raise ValidationError("--model-ema-steps must be >= 1")
+
     # Compile args, argument dependant
     if hasattr(args, "compile_teacher") is True:
         if args.compile is True and args.compile_teacher is True:
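Both new flags use `FlexibleDictAction`, whose parsing rules live in `birder/common/cli.py` and are not part of this diff. Going only by the `LAYER=WD` / `LAYER=SCALE` metavars and the help-text examples, an invocation would presumably look like the hypothetical sketch below and land on the parsed namespace as plain dictionaries:

```python
# Hypothetical command line (values taken from the help-text examples above):
#   train ... --custom-layer-wd offset_conv=0.0 \
#             --custom-layer-lr-scale offset_conv=0.01,attention=0.5
#
# Presumed shape of the parsed values (illustration only, not documented behavior):
custom_layer_wd = {"offset_conv": 0.0}
custom_layer_lr_scale = {"offset_conv": 0.01, "attention": 0.5}
```

These dictionaries are what `optimizer_parameter_groups` in the next file matches against parameter names by substring.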
birder/common/training_utils.py
CHANGED

@@ -3,8 +3,10 @@ import contextlib
 import logging
 import math
 import os
+import random
 import re
 import subprocess
+import sys
 from collections import deque
 from collections.abc import Callable
 from collections.abc import Generator
@@ -29,12 +31,25 @@ from birder.data.transforms.classification import training_preset
 from birder.optim import Lamb
 from birder.optim import Lars
 from birder.scheduler import CooldownLR
+from birder.version import __version__ as birder_version
 
 logger = logging.getLogger(__name__)
 
 OptimizerType = Literal["sgd", "rmsprop", "adam", "adamw", "nadam", "nadamw", "lamb", "lambw", "lars"]
 SchedulerType = Literal["constant", "step", "multistep", "cosine", "polynomial"]
 
+###############################################################################
+# Core Utilities
+###############################################################################
+
+
+def set_random_seeds(seed: int) -> None:
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
 ###############################################################################
 # Data Sampling
 ###############################################################################
@@ -207,13 +222,16 @@ def count_layers(model: torch.nn.Module) -> int:
 def optimizer_parameter_groups(
     model: torch.nn.Module,
     weight_decay: float,
+    base_lr: float,
     norm_weight_decay: Optional[float] = None,
     custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
+    custom_layer_weight_decay: Optional[dict[str, float]] = None,
     layer_decay: Optional[float] = None,
     layer_decay_min_scale: Optional[float] = None,
     layer_decay_no_opt_scale: Optional[float] = None,
     bias_lr: Optional[float] = None,
     backbone_lr: Optional[float] = None,
+    custom_layer_lr_scale: Optional[dict[str, float]] = None,
 ) -> list[dict[str, Any]]:
     """
     Return parameter groups for optimizers with per-parameter group weight decay.
@@ -233,11 +251,16 @@
         The PyTorch model whose parameters will be grouped for optimization.
     weight_decay
         Default weight decay (L2 regularization) value applied to parameters.
+    base_lr
+        Base learning rate that will be scaled by lr_scale factors for each parameter group.
     norm_weight_decay
         Weight decay value specifically for normalization layers. If None, uses weight_decay.
     custom_keys_weight_decay
         List of (parameter_name, weight_decay) tuples for applying custom weight decay
         values to specific parameters by name matching.
+    custom_layer_weight_decay
+        Dictionary mapping layer name substrings to custom weight decay values.
+        Applied to parameters whose names contain the specified keys.
     layer_decay
         Layer-wise learning rate decay factor.
     layer_decay_min_scale
@@ -248,6 +271,9 @@
         Custom learning rate for bias parameters (parameters ending with '.bias').
     backbone_lr
         Custom learning rate for backbone parameters (parameters starting with 'backbone.').
+    custom_layer_lr_scale
+        Dictionary mapping layer name substrings to custom lr_scale values.
+        Applied to parameters whose names contain the specified keys.
 
     Returns
     -------
@@ -291,14 +317,14 @@
     if layer_decay is not None:
         layer_max = num_layers - 1
         layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
-        logger.info(f"Layer scaling
+        logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")
 
     # Set weight decay and layer decay
     idx = 0
     params = []
     module_stack_with_prefix = [(model, "")]
     visited_modules = []
-    while len(module_stack_with_prefix) > 0:
+    while len(module_stack_with_prefix) > 0:  # pylint: disable=too-many-nested-blocks
         skip_module = False
         (module, prefix) = module_stack_with_prefix.pop()
         if id(module) in visited_modules:
@@ -324,13 +350,35 @@
                 for key, custom_wd in custom_keys_weight_decay:
                     target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
                     if key == target_name_for_custom_key:
+                        # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
+                        lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
+                        if custom_layer_lr_scale is not None:
+                            for layer_name_key, custom_scale in custom_layer_lr_scale.items():
+                                if layer_name_key in target_name:
+                                    lr_scale = custom_scale
+                                    break
+
+                        # Apply custom layer weight decay (substring matching)
+                        wd = custom_wd
+                        if custom_layer_weight_decay is not None:
+                            for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
+                                if layer_name_key in target_name:
+                                    wd = custom_wd_value
+                                    break
+
                         d = {
                             "params": p,
-                            "weight_decay":
-                            "lr_scale":
+                            "weight_decay": wd,
+                            "lr_scale": lr_scale,  # Used only for reference/debugging
                         }
-
+
+                        # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
+                        if bias_lr is not None and target_name.endswith(".bias") is True:
+                            d["lr"] = bias_lr
+                        elif backbone_lr is not None and target_name.startswith("backbone.") is True:
                             d["lr"] = backbone_lr
+                        elif lr_scale != 1.0:
+                            d["lr"] = base_lr * lr_scale
 
                         params.append(d)
                         is_custom_key = True
@@ -342,16 +390,34 @@
                 else:
                     wd = weight_decay
 
+                # Apply custom layer weight decay (substring matching)
+                if custom_layer_weight_decay is not None:
+                    for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
+                        if layer_name_key in target_name:
+                            wd = custom_wd_value
+                            break
+
+                # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
+                lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
+                if custom_layer_lr_scale is not None:
+                    for layer_name_key, custom_scale in custom_layer_lr_scale.items():
+                        if layer_name_key in target_name:
+                            lr_scale = custom_scale
+                            break
+
                 d = {
                     "params": p,
                     "weight_decay": wd,
-                    "lr_scale":
+                    "lr_scale": lr_scale,  # Used only for reference/debugging
                 }
-                if backbone_lr is not None and target_name.startswith("backbone.") is True:
-                    d["lr"] = backbone_lr
 
+                # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
                 if bias_lr is not None and target_name.endswith(".bias") is True:
                     d["lr"] = bias_lr
+                elif backbone_lr is not None and target_name.startswith("backbone.") is True:
+                    d["lr"] = backbone_lr
+                elif lr_scale != 1.0:
+                    d["lr"] = base_lr * lr_scale
 
                 params.append(d)
 
@@ -442,6 +508,8 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
     else:
         raise ValueError("Unknown optimizer")
 
+    logger.debug(f"Created {opt} optimizer with lr={lr}, weight_decay={args.wd}")
+
     return optimizer
 
 
@@ -477,10 +545,10 @@ def get_scheduler(
 
     main_steps = steps - begin_step - remaining_warmup - remaining_cooldown - 1
 
-    logger.debug(f"
+    logger.debug(f"Scheduler using {steps_per_epoch} steps per epoch")
     logger.debug(
         f"Scheduler {args.lr_scheduler} set for {steps} steps of which {warmup_steps} "
-        f"are warmup and {cooldown_steps} cooldown"
+        f"are warmup and {cooldown_steps} are cooldown"
     )
     logger.debug(
         f"Currently starting from step {begin_step} with {remaining_warmup} remaining warmup steps "
@@ -491,12 +559,29 @@
     if args.lr_scheduler == "constant":
         main_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
     elif args.lr_scheduler == "step":
+        # Note: StepLR step_size is relative to when the main scheduler starts (after warmup)
+        # This means drops occur relative to the end of warmup, not at absolute epoch numbers
         main_scheduler = torch.optim.lr_scheduler.StepLR(
             optimizer, step_size=args.lr_step_size, gamma=args.lr_step_gamma
         )
     elif args.lr_scheduler == "multistep":
+        # For MultiStepLR, milestones should be absolute step numbers
+        # Adjust them to be relative to when the main scheduler starts (after warmup)
+        # This ensures drops occur at the specified absolute steps, not relative to after warmup
+        adjusted_milestones = [m - warmup_steps for m in args.lr_steps if m >= warmup_steps]
+        if len(adjusted_milestones) == 0:
+            logger.debug(
+                f"All MultiStepLR milestones {args.lr_steps} are before warmup "
+                f"(warmup ends at step {warmup_steps}). Using empty milestone list."
+            )
+            adjusted_milestones = []
+
+        logger.debug(
+            f"MultiStepLR milestones adjusted from {args.lr_steps} to {adjusted_milestones} "
+            f"(relative to main scheduler start after {warmup_steps} warmup steps)"
+        )
         main_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer, milestones=
+            optimizer, milestones=adjusted_milestones, gamma=args.lr_step_gamma
         )
     elif args.lr_scheduler == "cosine":
         main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
@@ -793,6 +878,51 @@ def is_local_primary(args: argparse.Namespace) -> bool:
     return args.local_rank == 0  # type: ignore[no-any-return]
 
 
+def init_training(
+    args: argparse.Namespace,
+    log: logging.Logger,
+    *,
+    cudnn_dynamic_size: bool = False,
+) -> tuple[torch.device, int, bool]:
+    init_distributed_mode(args)
+
+    log.info(f"Starting training, birder version: {birder_version}, pytorch version: {torch.__version__}")
+
+    log_git_info()
+
+    if args.cpu is True:
+        device = torch.device("cpu")
+        device_id = 0
+    else:
+        device = torch.device("cuda")
+        device_id = torch.cuda.current_device()
+
+    if args.use_deterministic_algorithms is True:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    elif cudnn_dynamic_size is True:
+        # Dynamic sizes: avoid per-size algorithm selection overhead.
+        torch.backends.cudnn.enabled = False
+    else:
+        torch.backends.cudnn.enabled = True
+        torch.backends.cudnn.benchmark = True
+
+    if args.seed is not None:
+        set_random_seeds(args.seed)
+
+    if args.non_interactive is True or is_local_primary(args) is False:
+        disable_tqdm = True
+    elif sys.stderr.isatty() is False:
+        disable_tqdm = True
+    else:
+        disable_tqdm = False
+
+    # Enable or disable the autograd anomaly detection.
+    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
+
+    return (device, device_id, disable_tqdm)
+
+
 ###############################################################################
 # Utility Functions
 ###############################################################################
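Two details of the optimizer/scheduler changes above are easy to miss. First, parameter groups now receive an explicit lr of base_lr * lr_scale whenever a scale other than 1.0 applies, with bias_lr and backbone_lr taking precedence. The sketch below restates that priority with a hypothetical `resolve_group_lr` helper (not a birder API); the values are illustrative:

```python
from typing import Optional


def resolve_group_lr(
    name: str,
    base_lr: float,
    lr_scale: float = 1.0,
    bias_lr: Optional[float] = None,
    backbone_lr: Optional[float] = None,
) -> Optional[float]:
    # Mirrors the priority in optimizer_parameter_groups: bias_lr > backbone_lr > scaled base lr
    if bias_lr is not None and name.endswith(".bias"):
        return bias_lr
    if backbone_lr is not None and name.startswith("backbone."):
        return backbone_lr
    if lr_scale != 1.0:
        return base_lr * lr_scale  # lr_scale comes from layer decay or custom_layer_lr_scale
    return None  # no explicit "lr" key; the group falls back to the optimizer default


assert resolve_group_lr("backbone.stem.weight", 1e-3, lr_scale=0.5, backbone_lr=1e-4) == 1e-4
assert resolve_group_lr("encoder.block.0.attn.proj.weight", 1e-3, lr_scale=0.5) == 5e-4
assert resolve_group_lr("head.weight", 1e-3) is None
```

Second, `--lr-steps` milestones are now treated as absolute and shifted by the warmup length before being handed to MultiStepLR: with milestones 30, 60, 90 and a warmup of 5, the scheduler receives [25, 55, 85], and milestones that fall inside warmup are dropped.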
birder/data/collators/detection.py
CHANGED

@@ -1,13 +1,14 @@
 import math
 import random
 from typing import Any
+from typing import Optional
 
 import torch
 from torchvision import tv_tensors
 from torchvision.transforms import v2
 from torchvision.transforms.v2 import functional as F
 
-
+from birder.data.transforms.detection import build_multiscale_sizes
 
 
 def collate_fn(batch: list[tuple[Any, ...]]) -> tuple[Any, ...]:
@@ -63,13 +64,19 @@ class DetectionCollator:
 
 
 class BatchRandomResizeCollator(DetectionCollator):
-    def __init__(
+    def __init__(
+        self,
+        input_offset: int,
+        size: tuple[int, int],
+        size_divisible: int = 32,
+        multiscale_min_size: Optional[int] = None,
+    ) -> None:
         super().__init__(input_offset, size_divisible=size_divisible)
         if size is None:
             raise ValueError("size must be provided for batch multiscale")
 
         max_side = max(size)
-        sizes = [side for side in
+        sizes = [side for side in build_multiscale_sizes(multiscale_min_size) if side <= max_side]
         if len(sizes) == 0:
             sizes = [max_side]
 
birder/data/datasets/coco.py
CHANGED

@@ -98,10 +98,14 @@ class CocoTraining(CocoBase):
 class CocoInference(CocoBase):
     def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any, list[int]]:
         coco_id = self.dataset.ids[index]
-
+        img_info = self.dataset.coco.loadImgs(coco_id)[0]
+        path = img_info["file_name"]
         (sample, labels) = self.dataset[index]
 
-
+        # Get original image size (height, width) before transforms
+        orig_size = [img_info["height"], img_info["width"]]
+
+        return (path, sample, labels, orig_size)
 
 
 class CocoMosaicTraining(CocoBase):
@@ -127,9 +131,7 @@ class CocoMosaicTraining(CocoBase):
         self._mosaic_decay_epochs: Optional[int] = None
         self._mosaic_decay_start: Optional[int] = None
 
-    def configure_mosaic_linear_decay(
-        self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1
-    ) -> None:
+    def configure_mosaic_linear_decay(self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1) -> None:
         if total_epochs <= 0:
             raise ValueError("total_epochs must be positive")
         if decay_fraction <= 0.0 or decay_fraction > 1.0:
@@ -141,11 +143,7 @@
         self._mosaic_decay_start = max(1, total_epochs - decay_epochs + 1)
 
     def update_mosaic_prob(self, epoch: int) -> Optional[float]:
-        if
-            self._mosaic_base_prob is None
-            or self._mosaic_decay_epochs is None
-            or self._mosaic_decay_start is None
-        ):
+        if self._mosaic_base_prob is None or self._mosaic_decay_epochs is None or self._mosaic_decay_start is None:
             return None
 
         if epoch >= self._mosaic_decay_start:
birder/data/transforms/detection.py
CHANGED

@@ -1,3 +1,4 @@
+import math
 import random
 from collections.abc import Callable
 from typing import Any
@@ -10,6 +11,24 @@ from torchvision.transforms import v2
 
 from birder.data.transforms.classification import RGBType
 
+MULTISCALE_STEP = 32
+DEFAULT_MULTISCALE_MIN_SIZE = 480
+DEFAULT_MULTISCALE_MAX_SIZE = 800
+
+
+def build_multiscale_sizes(
+    min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE
+) -> tuple[int, ...]:
+    if min_size is None:
+        min_size = DEFAULT_MULTISCALE_MIN_SIZE
+
+    start = int(math.ceil(min_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+    end = int(math.floor(max_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+    if end < start:
+        return (start,)
+
+    return tuple(range(start, end + 1, MULTISCALE_STEP))
+
 
 class ResizeWithRandomInterpolation(nn.Module):
     def __init__(
@@ -39,6 +58,7 @@ def get_birder_augment(
     dynamic_size: bool,
     multiscale: bool,
     max_size: Optional[int],
+    multiscale_min_size: Optional[int],
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     if dynamic_size is True:
@@ -78,9 +98,7 @@ def get_birder_augment(
     # Resize
     if multiscale is True:
         transformations.append(
-            v2.RandomShortestSize(
-                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-            ),
+            v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
         )
     else:
         transformations.append(
@@ -132,6 +150,7 @@ def get_birder_augment(
 AugType = Literal["birder", "lsj", "multiscale", "ssd", "ssdlite", "yolo", "detr"]
 
 
+# pylint: disable=too-many-return-statements
 def training_preset(
     size: tuple[int, int],
     aug_type: AugType,
@@ -140,6 +159,7 @@ def training_preset(
     dynamic_size: bool = False,
     multiscale: bool = False,
     max_size: Optional[int] = None,
+    multiscale_min_size: Optional[int] = None,
     post_mosaic: bool = False,
 ) -> Callable[..., torch.Tensor]:
     mean = rgv_values["mean"]
@@ -159,7 +179,9 @@ def training_preset(
         return v2.Compose(  # type:ignore
             [
                 v2.ToImage(),
-                get_birder_augment(
+                get_birder_augment(
+                    size, level, fill_value, dynamic_size, multiscale, max_size, multiscale_min_size, post_mosaic
+                ),
                 v2.ToDtype(torch.float32, scale=True),
                 v2.Normalize(mean=mean, std=std),
                 v2.ToPureTensor(),
@@ -190,9 +212,7 @@ def training_preset(
         return v2.Compose(  # type: ignore
             [
                 v2.ToImage(),
-                v2.RandomShortestSize(
-                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                ),
+                v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
                 v2.RandomHorizontalFlip(0.5),
                 v2.SanitizeBoundingBoxes(),
                 v2.ToDtype(torch.float32, scale=True),
@@ -264,21 +284,18 @@ def training_preset(
         )
 
     if aug_type == "detr":
+        multiscale_sizes = build_multiscale_sizes(multiscale_min_size)
        return v2.Compose(  # type: ignore
            [
                v2.ToImage(),
                v2.RandomChoice(
                    [
-                        v2.RandomShortestSize(
-                            (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                        ),
+                        v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                        v2.Compose(
                            [
                                v2.RandomShortestSize((400, 500, 600)),
                                v2.RandomIoUCrop() if post_mosaic is False else v2.Identity(),  # RandomSizeCrop
-                                v2.RandomShortestSize(
-                                    (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                                ),
+                                v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                            ]
                        ),
                    ]
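Because `build_multiscale_sizes` is added in full above, its behavior can be exercised directly. The import path follows the file list at the top of this diff, and the last part of the sketch mirrors the fallback logic in `BatchRandomResizeCollator`:

```python
from birder.data.transforms.detection import build_multiscale_sizes

# Defaults reproduce the previously hard-coded tuple 480, 512, ..., 800
assert build_multiscale_sizes() == tuple(range(480, 801, 32))

# A minimum of 500 is rounded up to the next multiple of 32 (512)
assert build_multiscale_sizes(500) == tuple(range(512, 801, 32))

# If rounding pushes the start past max_size, a single size is returned
assert build_multiscale_sizes(900, max_size=800) == (928,)

# BatchRandomResizeCollator keeps only sizes up to max(size); with size=(640, 640)
# and --multiscale-min-size 672, nothing fits and it falls back to [max_side]
max_side = 640
sizes = [side for side in build_multiscale_sizes(672) if side <= max_side]
if len(sizes) == 0:
    sizes = [max_side]
assert sizes == [640]
```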
birder/inference/data_parallel.py
CHANGED

@@ -1,8 +1,7 @@
 """
 Inference-optimized multi-GPU parallelization
 
-This module provides InferenceDataParallel, an inference-specific alternative to
-torch.nn.DataParallel.
+This module provides InferenceDataParallel, an inference-specific alternative to torch.nn.DataParallel.
 """
 
 import copy