birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 from birder.common.fs_ops import load_model_with_cfg
 from birder.common.fs_ops import load_pretrained_model
+from birder.common.fs_ops import load_pretrained_model_and_transform
 from birder.common.lib import get_channels_from_signature
 from birder.common.lib import get_size_from_signature
 from birder.data.transforms.classification import inference_preset as classification_transform
@@ -17,5 +18,6 @@ __all__ = [
     "list_pretrained_models",
     "load_model_with_cfg",
     "load_pretrained_model",
+    "load_pretrained_model_and_transform",
     "__version__",
 ]
birder/common/fs_ops.py
CHANGED
@@ -2,6 +2,7 @@ import json
 import logging
 import os
 import re
+from collections.abc import Callable
 from collections.abc import Iterator
 from pathlib import Path
 from typing import Any
@@ -24,6 +25,8 @@ from birder.common.lib import get_network_name
 from birder.common.lib import get_pretrained_model_url
 from birder.conf import settings
 from birder.data.transforms.classification import RGBType
+from birder.data.transforms.classification import inference_preset
+from birder.data.transforms.detection import InferenceTransform
 from birder.model_registry import Task
 from birder.model_registry import registry
 from birder.model_registry.manifest import FileFormatType
@@ -801,7 +804,8 @@ def load_detection_model(
         for param in net.parameters():
             param.requires_grad_(False)

-        net.eval()
+        if pt2 is False:  # NOTE: Remove when GraphModule add support for 'eval'
+            net.eval()

     if len(backbone_loaded_config) == 0:
         backbone_custom_config = None
@@ -918,6 +922,82 @@ def load_pretrained_model(
     raise ValueError(f"Unknown model type: {model_metadata['task']}")


+def load_pretrained_model_and_transform(
+    weights: str,
+    *,
+    dst: Optional[str | Path] = None,
+    file_format: FileFormatType = "pt",
+    inference: bool = True,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+    custom_config: Optional[dict[str, Any]] = None,
+    progress_bar: bool = True,
+    classification_kwargs: Optional[dict[str, Any]] = None,
+    detection_kwargs: Optional[dict[str, Any]] = None,
+) -> tuple[BaseNet | DetectionBaseNet, ModelInfo | DetectionModelInfo, Callable[..., torch.Tensor]]:
+    """
+    Loads a pre-trained model and builds the matching inference transform
+
+    This is a convenience helper for the common inference path where the model and
+    its default preprocessing are needed together. Classification models use
+    inference_preset, detection models use InferenceTransform.
+
+    Parameters
+    ----------
+    weights
+        Name of the pre-trained weights to load from the model registry.
+    dst
+        Destination path where the model weights will be downloaded or loaded from.
+    file_format
+        Model format (e.g. pt, pt2, safetensors, etc.)
+    inference
+        Flag to prepare the model for inference mode.
+    device
+        The device to load the model on (cpu/cuda).
+    dtype
+        Data type for model parameters and computations (e.g., torch.float32, torch.float16).
+    custom_config
+        Additional model configuration that overrides or extends the predefined configuration.
+    progress_bar
+        Whether to display a progress bar during file download.
+    classification_kwargs
+        Optional keyword arguments forwarded to inference_preset.
+    detection_kwargs
+        Optional keyword arguments forwarded to InferenceTransform. If dynamic_size is
+        not provided it defaults to the model signature value.
+
+    Returns
+    -------
+    A tuple containing three elements:
+        - A PyTorch module (neural network model) loaded with pre-trained weights.
+        - Model info containing class mappings, signature, and RGB stats.
+        - An inference transform matching the model task.
+    """
+
+    net, model_info = load_pretrained_model(
+        weights,
+        dst=dst,
+        file_format=file_format,
+        inference=inference,
+        device=device,
+        dtype=dtype,
+        custom_config=custom_config,
+        progress_bar=progress_bar,
+    )
+
+    size = lib.get_size_from_signature(model_info.signature)
+    transform: Callable[..., torch.Tensor]
+    if isinstance(model_info, DetectionModelInfo):
+        detection_args = {} if detection_kwargs is None else dict(detection_kwargs)
+        detection_args.setdefault("dynamic_size", model_info.signature["dynamic"])
+        transform = InferenceTransform(size, model_info.rgb_stats, **detection_args)
+    else:
+        classification_args = {} if classification_kwargs is None else dict(classification_kwargs)
+        transform = inference_preset(size, model_info.rgb_stats, **classification_args)
+
+    return (net, model_info, transform)
+
+
 def load_model_with_cfg(
     cfg: dict[str, Any] | str | Path, weights_path: Optional[str | Path]
 ) -> tuple[torch.nn.Module, dict[str, Any]]:
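The new load_pretrained_model_and_transform helper is also exported at the package root (see the __init__.py diff above). A minimal usage sketch; the weights name below is a placeholder, and passing a PIL image assumes the classification preset behaves like a standard torchvision transform:

import torch
from PIL import Image

import birder

# "some_weights_name" is a placeholder; pick any entry from birder.list_pretrained_models()
net, model_info, transform = birder.load_pretrained_model_and_transform("some_weights_name", inference=True)

img = Image.open("example.jpeg")
with torch.inference_mode():
    logits = net(transform(img).unsqueeze(0))  # transform yields a CHW tensor, so add the batch dim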
birder/common/training_cli.py
CHANGED
@@ -56,7 +56,9 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
     )


-def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False) -> None:
+def add_lr_wd_args(
+    parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False, backbone_layer_decay: bool = False
+) -> None:
     group = parser.add_argument_group("Learning rate and regularization parameters")
     group.add_argument("--lr", type=float, default=0.1, metavar="LR", help="base learning rate")
     group.add_argument("--bias-lr", type=float, metavar="LR", help="learning rate of biases")
@@ -92,6 +94,9 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
         help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
     )
     group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
+    if backbone_layer_decay is True:
+        group.add_argument("--backbone-layer-decay", type=float, help="backbone layer-wise learning rate decay (LLRD)")
+
     group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
     group.add_argument(
         "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
@@ -480,8 +485,13 @@ def add_dataloader_args(
     )


-def add_precision_args(parser: argparse.ArgumentParser) -> None:
+def add_precision_args(parser: argparse.ArgumentParser, channels_last: bool = False) -> None:
     group = parser.add_argument_group("Precision parameters")
+    if channels_last is True:
+        group.add_argument(
+            "--channels-last", default=False, action="store_true", help="use channels-last memory format"
+        )
+
     group.add_argument(
         "--model-dtype",
         type=str,
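A sketch of how a training script can opt in to the new switches; the wiring below is illustrative rather than copied from any birder script:

import argparse

from birder.common import training_cli

parser = argparse.ArgumentParser()
training_cli.add_lr_wd_args(parser, backbone_lr=True, backbone_layer_decay=True)
training_cli.add_precision_args(parser, channels_last=True)

args = parser.parse_args(["--backbone-layer-decay", "0.8", "--channels-last"])
print(args.backbone_layer_decay, args.channels_last)  # 0.8 True

Both additions are opt-in: scripts that do not pass backbone_layer_decay=True or channels_last=True keep their previous argument surface.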
birder/common/training_utils.py
CHANGED
@@ -343,7 +343,7 @@ def count_layers(model: torch.nn.Module) -> int:
     return num_layers


-# pylint: disable=protected-access,too-many-locals,too-many-branches
+# pylint: disable=protected-access,too-many-locals,too-many-branches,too-many-statements
 def optimizer_parameter_groups(
     model: torch.nn.Module,
     weight_decay: float,
@@ -352,6 +352,7 @@ def optimizer_parameter_groups(
     custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
     custom_layer_weight_decay: Optional[dict[str, float]] = None,
     layer_decay: Optional[float] = None,
+    backbone_layer_decay: Optional[float] = None,
     layer_decay_min_scale: Optional[float] = None,
     layer_decay_no_opt_scale: Optional[float] = None,
     bias_lr: Optional[float] = None,
@@ -388,6 +389,8 @@ def optimizer_parameter_groups(
         Applied to parameters whose names contain the specified keys.
     layer_decay
         Layer-wise learning rate decay factor.
+    backbone_layer_decay
+        Layer-wise learning rate decay factor for backbone parameters only.
     layer_decay_min_scale
         Minimum learning rate scale factor when using layer decay. Prevents layers from having too small learning rates.
     layer_decay_no_opt_scale
@@ -434,6 +437,27 @@
     if layer_decay is not None:
         logger.warning("Assigning lr scaling (layer decay) without a block group map")

+    backbone_group_map: dict[str, int] = {}
+    backbone_num_layers = 0
+    if backbone_layer_decay is not None:
+        backbone_module = getattr(model, "backbone", None)
+        if backbone_module is None:
+            logger.warning("Backbone layer decay requested but model has no backbone")
+            backbone_layer_decay = None
+        else:
+            backbone_block_group_regex = getattr(backbone_module, "block_group_regex", None)
+            if backbone_block_group_regex is not None:
+                names = [n for n, _ in backbone_module.named_parameters()]
+                groups = group_by_regex(names, backbone_block_group_regex)
+                backbone_group_map = {
+                    f"backbone.{item}": index for index, sublist in enumerate(groups) for item in sublist
+                }
+                backbone_num_layers = len(groups)
+            else:
+                backbone_group_map = {}
+                backbone_num_layers = count_layers(backbone_module)
+                logger.warning("Assigning lr scaling (backbone layer decay) without a block group map")
+
     # Build layer scale
     if layer_decay_min_scale is None:
         layer_decay_min_scale = 0.0
@@ -444,14 +468,28 @@
     layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
     logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")

+    backbone_layer_scales = []
+    if backbone_layer_decay is not None:
+        backbone_layer_max = backbone_num_layers - 1
+        backbone_layer_scales = [
+            max(layer_decay_min_scale, backbone_layer_decay ** (backbone_layer_max - i))
+            for i in range(backbone_num_layers)
+        ]
+        logger.info(
+            "Backbone layer scaling ranges from "
+            f"{min(backbone_layer_scales)} to {max(backbone_layer_scales)} across {backbone_num_layers} layers"
+        )
+
     # Set weight decay and layer decay
     idx = 0
+    backbone_idx = 0
     params = []
     module_stack_with_prefix = [(model, "")]
     visited_modules = []
     while len(module_stack_with_prefix) > 0:  # pylint: disable=too-many-nested-blocks
         skip_module = False
         module, prefix = module_stack_with_prefix.pop()
+        is_backbone_module = prefix == "backbone" or prefix.startswith("backbone.")
         if id(module) in visited_modules:
             skip_module = True

@@ -460,23 +498,35 @@
         for name, p in module.named_parameters(recurse=False):
             target_name = f"{prefix}.{name}" if prefix != "" else name
             idx = group_map.get(target_name, idx)
+            is_backbone_param = target_name.startswith("backbone.")
+            if backbone_layer_decay is not None and is_backbone_param is True:
+                backbone_idx = backbone_group_map.get(target_name, backbone_idx)
             if skip_module is True:
                 break

             parameters_found = True
             if p.requires_grad is False:
                 continue
-            if layer_decay_no_opt_scale is not None:
-                if layer_scales[idx] < layer_decay_no_opt_scale:
-                    p.requires_grad_(False)
+            if layer_decay_no_opt_scale is not None:
+                if backbone_layer_decay is not None and is_backbone_param is True:
+                    if backbone_layer_scales and backbone_layer_scales[backbone_idx] < layer_decay_no_opt_scale:
+                        p.requires_grad_(False)
+                elif layer_decay is not None:
+                    if layer_scales[idx] < layer_decay_no_opt_scale:
+                        p.requires_grad_(False)

             is_custom_key = False
             if custom_keys_weight_decay is not None:
                 for key, custom_wd in custom_keys_weight_decay:
                     target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
                     if key == target_name_for_custom_key:
-                        # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
-
+                        # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
+                        if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
+                            lr_scale = layer_scales[idx]
+                        elif backbone_layer_decay is not None and is_backbone_param is True:
+                            lr_scale = backbone_layer_scales[backbone_idx]
+                        else:
+                            lr_scale = 1.0
                         if custom_layer_lr_scale is not None:
                             for layer_name_key, custom_scale in custom_layer_lr_scale.items():
                                 if layer_name_key in target_name:
@@ -500,8 +550,8 @@
                         # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
                         if bias_lr is not None and target_name.endswith(".bias") is True:
                             d["lr"] = bias_lr
-                        elif backbone_lr is not None and target_name.startswith("backbone.") is True:
-                            d["lr"] = backbone_lr
+                        elif backbone_lr is not None and is_backbone_param is True:
+                            d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
                         elif lr_scale != 1.0:
                             d["lr"] = base_lr * lr_scale

@@ -522,8 +572,13 @@
                     wd = custom_wd_value
                     break

-            # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
-
+            # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
+            if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
+                lr_scale = layer_scales[idx]
+            elif backbone_layer_decay is not None and is_backbone_param is True:
+                lr_scale = backbone_layer_scales[backbone_idx]
+            else:
+                lr_scale = 1.0
             if custom_layer_lr_scale is not None:
                 for layer_name_key, custom_scale in custom_layer_lr_scale.items():
                     if layer_name_key in target_name:
@@ -539,8 +594,8 @@
             # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
             if bias_lr is not None and target_name.endswith(".bias") is True:
                 d["lr"] = bias_lr
-            elif backbone_lr is not None and target_name.startswith("backbone.") is True:
-                d["lr"] = backbone_lr
+            elif backbone_lr is not None and is_backbone_param is True:
+                d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
             elif lr_scale != 1.0:
                 d["lr"] = base_lr * lr_scale

@@ -548,6 +603,8 @@

         if parameters_found is True:
             idx += 1
+            if is_backbone_module is True:
+                backbone_idx += 1

         for child_name, child_module in reversed(list(module.named_children())):
             child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
@@ -1108,12 +1165,16 @@ def init_training(
         device_id = torch.cuda.current_device()

     if args.use_deterministic_algorithms is True:
+        log.debug("Turning on deterministic algorithms")
         torch.backends.cudnn.benchmark = False
         torch.use_deterministic_algorithms(True)
     elif cudnn_dynamic_size is True:
         # Dynamic sizes: avoid per-size algorithm selection overhead.
+        log.debug("Turning off cudnn")
         torch.backends.cudnn.enabled = False
+        torch.backends.cudnn.benchmark = False
     else:
+        log.debug("Turning on cudnn")
         torch.backends.cudnn.enabled = True
         torch.backends.cudnn.benchmark = True

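The backbone scales follow the same rule as the global layer scales, scale_i = max(min_scale, decay ** (num_layers - 1 - i)), so earlier backbone groups get geometrically smaller learning rates. A worked example of the arithmetic, assuming four backbone groups, a decay of 0.8, and no minimum clamp:

backbone_layer_decay = 0.8
backbone_num_layers = 4
layer_decay_min_scale = 0.0

backbone_layer_max = backbone_num_layers - 1
scales = [
    max(layer_decay_min_scale, backbone_layer_decay ** (backbone_layer_max - i))
    for i in range(backbone_num_layers)
]
print(scales)  # approx [0.512, 0.64, 0.8, 1.0]; the earliest group trains slowest

When --backbone-lr is also given, each backbone parameter group then gets backbone_lr * scale instead of the base lr, per the d["lr"] assignment above.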
birder/data/collators/detection.py
CHANGED
@@ -15,7 +15,9 @@ def collate_fn(batch: list[tuple[Any, ...]]) -> tuple[Any, ...]:
     return tuple(zip(*batch))


-def batch_images(
+def batch_images(
+    images: list[torch.Tensor], size_divisible: int
+) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]]]:
     """
     Batch list of image tensors of different sizes into a single batch.
     Pad with zeros all images to the shape of the largest image in the list.
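batch_images now returns a padded batch tensor, a second tensor, and the original image sizes. A minimal sketch of zero-pad batching in that shape; it assumes, as in DETR-style collators, that the second tensor is a boolean padding mask. This is an illustration, not the birder implementation:

import torch

def batch_images_sketch(
    images: list[torch.Tensor], size_divisible: int = 32
) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]]]:
    # Pad every CHW image with zeros up to the largest (H, W) in the list,
    # rounded up to a multiple of size_divisible
    stride = size_divisible
    max_h = -(-max(img.shape[-2] for img in images) // stride) * stride
    max_w = -(-max(img.shape[-1] for img in images) // stride) * stride

    batch = images[0].new_zeros((len(images), images[0].shape[0], max_h, max_w))
    mask = torch.ones((len(images), max_h, max_w), dtype=torch.bool)  # True marks padding
    sizes: list[tuple[int, int]] = []
    for i, img in enumerate(images):
        (_, h, w) = img.shape
        batch[i, :, :h, :w].copy_(img)
        mask[i, :h, :w] = False
        sizes.append((h, w))

    return (batch, mask, sizes)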
birder/datahub/_lib.py
CHANGED
@@ -1,5 +1,6 @@
 import logging
 import tarfile
+import zipfile
 from pathlib import Path

 from birder.common import cli
@@ -26,9 +27,17 @@ def download_url(url: str, target: str | Path, sha256: str, progress_bar: bool =

 def extract_archive(from_path: str | Path, to_path: str | Path) -> None:
     logger.info(f"Extracting {from_path} to {to_path}")
-    with tarfile.open(from_path, "r") as tar:
-        if hasattr(tarfile, "data_filter") is True:
-            tar.extractall(to_path, filter="data")
-        else:
-            # NOTE: Remove once minimum Python version is 3.12 or above
-            tar.extractall(to_path)  # nosec # tarfile_unsafe_members
+    if isinstance(from_path, str):
+        from_path = Path(from_path)
+
+    if from_path.suffix == ".zip":
+        with zipfile.ZipFile(from_path, "r") as zf:
+            zf.extractall(to_path)  # nosec # tarfile_unsafe_members
+
+    else:
+        with tarfile.open(from_path, "r") as tar:
+            if hasattr(tarfile, "data_filter") is True:
+                tar.extractall(to_path, filter="data")
+            else:
+                # NOTE: Remove once minimum Python version is 3.12 or above
+                tar.extractall(to_path)  # nosec # tarfile_unsafe_members