birder-0.2.2-py3-none-any.whl → birder-0.2.3-py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (61)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +18 -0
  3. birder/common/training_utils.py +123 -10
  4. birder/data/collators/detection.py +10 -3
  5. birder/data/datasets/coco.py +8 -10
  6. birder/data/transforms/detection.py +30 -13
  7. birder/inference/detection.py +108 -4
  8. birder/inference/wbf.py +226 -0
  9. birder/net/__init__.py +8 -0
  10. birder/net/detection/efficientdet.py +65 -86
  11. birder/net/detection/rt_detr_v1.py +1 -0
  12. birder/net/detection/yolo_anchors.py +205 -0
  13. birder/net/detection/yolo_v2.py +25 -24
  14. birder/net/detection/yolo_v3.py +39 -40
  15. birder/net/detection/yolo_v4.py +28 -26
  16. birder/net/detection/yolo_v4_tiny.py +24 -20
  17. birder/net/fasternet.py +1 -1
  18. birder/net/gc_vit.py +671 -0
  19. birder/net/lit_v1.py +472 -0
  20. birder/net/lit_v1_tiny.py +342 -0
  21. birder/net/lit_v2.py +436 -0
  22. birder/net/mobilenet_v4_hybrid.py +1 -1
  23. birder/net/resnet_v1.py +1 -1
  24. birder/net/resnext.py +67 -25
  25. birder/net/se_resnet_v1.py +46 -0
  26. birder/net/se_resnext.py +3 -0
  27. birder/net/simple_vit.py +2 -2
  28. birder/net/vit.py +0 -15
  29. birder/net/vovnet_v2.py +31 -1
  30. birder/scripts/benchmark.py +90 -21
  31. birder/scripts/predict.py +1 -0
  32. birder/scripts/predict_detection.py +18 -11
  33. birder/scripts/train.py +10 -34
  34. birder/scripts/train_barlow_twins.py +10 -34
  35. birder/scripts/train_byol.py +10 -34
  36. birder/scripts/train_capi.py +10 -35
  37. birder/scripts/train_data2vec.py +9 -34
  38. birder/scripts/train_data2vec2.py +9 -34
  39. birder/scripts/train_detection.py +48 -40
  40. birder/scripts/train_dino_v1.py +10 -34
  41. birder/scripts/train_dino_v2.py +9 -34
  42. birder/scripts/train_dino_v2_dist.py +9 -34
  43. birder/scripts/train_franca.py +9 -34
  44. birder/scripts/train_i_jepa.py +9 -34
  45. birder/scripts/train_ibot.py +9 -34
  46. birder/scripts/train_kd.py +156 -64
  47. birder/scripts/train_mim.py +10 -34
  48. birder/scripts/train_mmcr.py +10 -34
  49. birder/scripts/train_rotnet.py +10 -34
  50. birder/scripts/train_simclr.py +10 -34
  51. birder/scripts/train_vicreg.py +10 -34
  52. birder/tools/auto_anchors.py +20 -1
  53. birder/tools/pack.py +172 -103
  54. birder/tools/show_det_iterator.py +10 -1
  55. birder/version.py +1 -1
  56. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
  57. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
  58. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
  59. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
  60. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
  61. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/common/lib.py CHANGED
@@ -1,11 +1,7 @@
  import os
- import random
  from typing import Any
  from typing import Optional

- import numpy as np
- import torch
-
  from birder.conf import settings
  from birder.data.transforms.classification import RGBType
  from birder.model_registry import registry
@@ -19,11 +15,8 @@ from birder.net.ssl.base import SSLBaseNet
  from birder.version import __version__


- def set_random_seeds(seed: int) -> None:
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed_all(seed)
-     np.random.seed(seed)
-     random.seed(seed)
+ def env_bool(name: str) -> bool:
+     return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}


  def get_size_from_signature(signature: SignatureType | DetectionSignatureType) -> tuple[int, int]:
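
Note: the new env_bool helper reads an environment variable and accepts common truthy spellings, case-insensitively. A minimal usage sketch (the BIRDER_DEBUG variable name is made up for illustration):

    import os

    from birder.common.lib import env_bool

    os.environ["BIRDER_DEBUG"] = "Yes"
    print(env_bool("BIRDER_DEBUG"))      # True: "1", "true", "yes" and "on" all match
    print(env_bool("BIRDER_UNSET_VAR"))  # False: missing variables default to ""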
birder/common/training_cli.py CHANGED
@@ -5,6 +5,7 @@ import typing
  from typing import Optional
  from typing import get_args

+ from birder.common.cli import FlexibleDictAction
  from birder.common.cli import ValidationError
  from birder.common.training_utils import OptimizerType
  from birder.common.training_utils import SchedulerType
@@ -82,11 +83,23 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
          metavar="WD",
          help="weight decay for embedding parameters for vision transformer models",
      )
+     group.add_argument(
+         "--custom-layer-wd",
+         action=FlexibleDictAction,
+         metavar="LAYER=WD",
+         help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
+     )
      group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
      group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
      group.add_argument(
          "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
      )
+     group.add_argument(
+         "--custom-layer-lr-scale",
+         action=FlexibleDictAction,
+         metavar="LAYER=SCALE",
+         help="custom lr_scale for specific layers by name (e.g., offset_conv=0.01,attention=0.5)",
+     )


  def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
@@ -185,6 +198,11 @@ def add_detection_input_args(parser: argparse.ArgumentParser) -> None:
          action="store_true",
          help="enable random square resize once per batch (capped by max(--size))",
      )
+     group.add_argument(
+         "--multiscale-min-size",
+         type=int,
+         help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
+     )


  def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
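
Note: judging by the metavar hints above (LAYER=WD, LAYER=SCALE), FlexibleDictAction presumably parses comma-separated key=value pairs into a dict; its implementation is not part of this diff. A stand-in sketch of that assumed mapping (parse_kv is hypothetical, not the package's class):

    def parse_kv(value: str) -> dict[str, float]:
        # "offset_conv=0.01,attention=0.5" -> {"offset_conv": 0.01, "attention": 0.5}
        out: dict[str, float] = {}
        for pair in value.split(","):
            (key, _, val) = pair.partition("=")
            out[key] = float(val)
        return out

    print(parse_kv("offset_conv=0.01,attention=0.5"))  # {'offset_conv': 0.01, 'attention': 0.5}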
birder/common/training_utils.py CHANGED
@@ -3,8 +3,10 @@ import contextlib
  import logging
  import math
  import os
+ import random
  import re
  import subprocess
+ import sys
  from collections import deque
  from collections.abc import Callable
  from collections.abc import Generator
@@ -29,12 +31,25 @@ from birder.data.transforms.classification import training_preset
  from birder.optim import Lamb
  from birder.optim import Lars
  from birder.scheduler import CooldownLR
+ from birder.version import __version__ as birder_version

  logger = logging.getLogger(__name__)

  OptimizerType = Literal["sgd", "rmsprop", "adam", "adamw", "nadam", "nadamw", "lamb", "lambw", "lars"]
  SchedulerType = Literal["constant", "step", "multistep", "cosine", "polynomial"]

+ ###############################################################################
+ # Core Utilities
+ ###############################################################################
+
+
+ def set_random_seeds(seed: int) -> None:
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+
+
  ###############################################################################
  # Data Sampling
  ###############################################################################
@@ -207,13 +222,16 @@ def count_layers(model: torch.nn.Module) -> int:
  def optimizer_parameter_groups(
      model: torch.nn.Module,
      weight_decay: float,
+     base_lr: float,
      norm_weight_decay: Optional[float] = None,
      custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
+     custom_layer_weight_decay: Optional[dict[str, float]] = None,
      layer_decay: Optional[float] = None,
      layer_decay_min_scale: Optional[float] = None,
      layer_decay_no_opt_scale: Optional[float] = None,
      bias_lr: Optional[float] = None,
      backbone_lr: Optional[float] = None,
+     custom_layer_lr_scale: Optional[dict[str, float]] = None,
  ) -> list[dict[str, Any]]:
      """
      Return parameter groups for optimizers with per-parameter group weight decay.
@@ -233,11 +251,16 @@
          The PyTorch model whose parameters will be grouped for optimization.
      weight_decay
          Default weight decay (L2 regularization) value applied to parameters.
+     base_lr
+         Base learning rate that will be scaled by lr_scale factors for each parameter group.
      norm_weight_decay
          Weight decay value specifically for normalization layers. If None, uses weight_decay.
      custom_keys_weight_decay
          List of (parameter_name, weight_decay) tuples for applying custom weight decay
          values to specific parameters by name matching.
+     custom_layer_weight_decay
+         Dictionary mapping layer name substrings to custom weight decay values.
+         Applied to parameters whose names contain the specified keys.
      layer_decay
          Layer-wise learning rate decay factor.
      layer_decay_min_scale
@@ -248,6 +271,9 @@
          Custom learning rate for bias parameters (parameters ending with '.bias').
      backbone_lr
          Custom learning rate for backbone parameters (parameters starting with 'backbone.').
+     custom_layer_lr_scale
+         Dictionary mapping layer name substrings to custom lr_scale values.
+         Applied to parameters whose names contain the specified keys.

      Returns
      -------
@@ -291,14 +317,14 @@
      if layer_decay is not None:
          layer_max = num_layers - 1
          layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
-         logger.info(f"Layer scaling in range of {min(layer_scales)} - {max(layer_scales)} on {num_layers} layers")
+         logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")

      # Set weight decay and layer decay
      idx = 0
      params = []
      module_stack_with_prefix = [(model, "")]
      visited_modules = []
-     while len(module_stack_with_prefix) > 0:
+     while len(module_stack_with_prefix) > 0:  # pylint: disable=too-many-nested-blocks
          skip_module = False
          (module, prefix) = module_stack_with_prefix.pop()
          if id(module) in visited_modules:
@@ -324,13 +350,35 @@
              for key, custom_wd in custom_keys_weight_decay:
                  target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
                  if key == target_name_for_custom_key:
+                     # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
+                     lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
+                     if custom_layer_lr_scale is not None:
+                         for layer_name_key, custom_scale in custom_layer_lr_scale.items():
+                             if layer_name_key in target_name:
+                                 lr_scale = custom_scale
+                                 break
+
+                     # Apply custom layer weight decay (substring matching)
+                     wd = custom_wd
+                     if custom_layer_weight_decay is not None:
+                         for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
+                             if layer_name_key in target_name:
+                                 wd = custom_wd_value
+                                 break
+
                      d = {
                          "params": p,
-                         "weight_decay": custom_wd,
-                         "lr_scale": 1.0 if layer_decay is None else layer_scales[idx],
+                         "weight_decay": wd,
+                         "lr_scale": lr_scale,  # Used only for reference/debugging
                      }
-                     if backbone_lr is not None and target_name.startswith("backbone.") is True:
+
+                     # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
+                     if bias_lr is not None and target_name.endswith(".bias") is True:
+                         d["lr"] = bias_lr
+                     elif backbone_lr is not None and target_name.startswith("backbone.") is True:
                          d["lr"] = backbone_lr
+                     elif lr_scale != 1.0:
+                         d["lr"] = base_lr * lr_scale

                      params.append(d)
                      is_custom_key = True
@@ -342,16 +390,34 @@
              else:
                  wd = weight_decay

+             # Apply custom layer weight decay (substring matching)
+             if custom_layer_weight_decay is not None:
+                 for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
+                     if layer_name_key in target_name:
+                         wd = custom_wd_value
+                         break
+
+             # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
+             lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
+             if custom_layer_lr_scale is not None:
+                 for layer_name_key, custom_scale in custom_layer_lr_scale.items():
+                     if layer_name_key in target_name:
+                         lr_scale = custom_scale
+                         break
+
              d = {
                  "params": p,
                  "weight_decay": wd,
-                 "lr_scale": 1.0 if layer_decay is None else layer_scales[idx],
+                 "lr_scale": lr_scale,  # Used only for reference/debugging
              }
-             if backbone_lr is not None and target_name.startswith("backbone.") is True:
-                 d["lr"] = backbone_lr

+             # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
              if bias_lr is not None and target_name.endswith(".bias") is True:
                  d["lr"] = bias_lr
+             elif backbone_lr is not None and target_name.startswith("backbone.") is True:
+                 d["lr"] = backbone_lr
+             elif lr_scale != 1.0:
+                 d["lr"] = base_lr * lr_scale

              params.append(d)

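Note: the learning-rate priority both branches now share is bias_lr > backbone_lr > base_lr * lr_scale. A hedged distillation of just that rule (resolve_lr is a hypothetical standalone helper, not a function in the package):

    def resolve_lr(name: str, base_lr: float, lr_scale: float,
                   bias_lr: float | None, backbone_lr: float | None) -> float:
        # Mirrors the priority used above: bias_lr > backbone_lr > base_lr * lr_scale
        if bias_lr is not None and name.endswith(".bias"):
            return bias_lr
        if backbone_lr is not None and name.startswith("backbone."):
            return backbone_lr
        return base_lr * lr_scale

    print(resolve_lr("backbone.stem.bias", 1e-3, 1.0, 2e-3, 1e-4))        # 0.002: bias_lr wins
    print(resolve_lr("backbone.stem.weight", 1e-3, 1.0, 2e-3, 1e-4))      # 0.0001: backbone_lr
    print(resolve_lr("head.offset_conv.weight", 1e-3, 0.01, None, None))  # ~1e-05: base_lr * lr_scale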
@@ -442,6 +508,8 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
      else:
          raise ValueError("Unknown optimizer")

+     logger.debug(f"Created {opt} optimizer with lr={lr}, weight_decay={args.wd}")
+
      return optimizer


@@ -477,10 +545,10 @@ def get_scheduler(

      main_steps = steps - begin_step - remaining_warmup - remaining_cooldown - 1

-     logger.debug(f"Using {steps_per_epoch} steps per epoch")
+     logger.debug(f"Scheduler using {steps_per_epoch} steps per epoch")
      logger.debug(
          f"Scheduler {args.lr_scheduler} set for {steps} steps of which {warmup_steps} "
-         f"are warmup and {cooldown_steps} cooldown"
+         f"are warmup and {cooldown_steps} are cooldown"
      )
      logger.debug(
          f"Currently starting from step {begin_step} with {remaining_warmup} remaining warmup steps "
@@ -810,6 +878,51 @@ def is_local_primary(args: argparse.Namespace) -> bool:
      return args.local_rank == 0  # type: ignore[no-any-return]


+ def init_training(
+     args: argparse.Namespace,
+     log: logging.Logger,
+     *,
+     cudnn_dynamic_size: bool = False,
+ ) -> tuple[torch.device, int, bool]:
+     init_distributed_mode(args)
+
+     log.info(f"Starting training, birder version: {birder_version}, pytorch version: {torch.__version__}")
+
+     log_git_info()
+
+     if args.cpu is True:
+         device = torch.device("cpu")
+         device_id = 0
+     else:
+         device = torch.device("cuda")
+         device_id = torch.cuda.current_device()
+
+     if args.use_deterministic_algorithms is True:
+         torch.backends.cudnn.benchmark = False
+         torch.use_deterministic_algorithms(True)
+     elif cudnn_dynamic_size is True:
+         # Dynamic sizes: avoid per-size algorithm selection overhead.
+         torch.backends.cudnn.enabled = False
+     else:
+         torch.backends.cudnn.enabled = True
+         torch.backends.cudnn.benchmark = True
+
+     if args.seed is not None:
+         set_random_seeds(args.seed)
+
+     if args.non_interactive is True or is_local_primary(args) is False:
+         disable_tqdm = True
+     elif sys.stderr.isatty() is False:
+         disable_tqdm = True
+     else:
+         disable_tqdm = False
+
+     # Enable or disable the autograd anomaly detection.
+     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
+
+     return (device, device_id, disable_tqdm)
+
+
  ###############################################################################
  # Utility Functions
  ###############################################################################
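
Note: init_training consolidates the setup boilerplate each train_* script previously carried (the file list above shows roughly 34 lines dropped per script). A hedged usage sketch; init_distributed_mode() reads additional launcher fields not shown in this diff, so the Namespace below is trimmed:

    import argparse
    import logging

    from birder.common import training_utils

    args = argparse.Namespace(
        cpu=True,  # keep the sketch on CPU
        use_deterministic_algorithms=False,
        seed=42,
        non_interactive=True,
        grad_anomaly_detection=False,
        # ... plus whatever distributed-launch fields init_distributed_mode() expects
    )
    (device, device_id, disable_tqdm) = training_utils.init_training(args, logging.getLogger(__name__))
    print(device, device_id, disable_tqdm)  # cpu 0 True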
birder/data/collators/detection.py CHANGED
@@ -1,13 +1,14 @@
  import math
  import random
  from typing import Any
+ from typing import Optional

  import torch
  from torchvision import tv_tensors
  from torchvision.transforms import v2
  from torchvision.transforms.v2 import functional as F

- BATCH_MULTISCALE_SIZES = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
+ from birder.data.transforms.detection import build_multiscale_sizes


  def collate_fn(batch: list[tuple[Any, ...]]) -> tuple[Any, ...]:
@@ -63,13 +64,19 @@ class DetectionCollator:


  class BatchRandomResizeCollator(DetectionCollator):
-     def __init__(self, input_offset: int, size: tuple[int, int], size_divisible: int = 32) -> None:
+     def __init__(
+         self,
+         input_offset: int,
+         size: tuple[int, int],
+         size_divisible: int = 32,
+         multiscale_min_size: Optional[int] = None,
+     ) -> None:
          super().__init__(input_offset, size_divisible=size_divisible)
          if size is None:
              raise ValueError("size must be provided for batch multiscale")

          max_side = max(size)
-         sizes = [side for side in BATCH_MULTISCALE_SIZES if side <= max_side]
+         sizes = [side for side in build_multiscale_sizes(multiscale_min_size) if side <= max_side]
          if len(sizes) == 0:
              sizes = [max_side]

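Note: with the default range the collator's behavior is unchanged; only the source of the candidate list moved. A worked filter for size=(640, 640), using the default tuple that build_multiscale_sizes reproduces:

    default_sizes = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
    sizes = [side for side in default_sizes if side <= 640]
    print(sizes)  # [480, 512, 544, 576, 608, 640]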
birder/data/datasets/coco.py CHANGED
@@ -98,10 +98,14 @@ class CocoTraining(CocoBase):
  class CocoInference(CocoBase):
      def __getitem__(self, index: int) -> tuple[str, torch.Tensor, Any, list[int]]:
          coco_id = self.dataset.ids[index]
-         path = self.dataset.coco.loadImgs(coco_id)[0]["file_name"]
+         img_info = self.dataset.coco.loadImgs(coco_id)[0]
+         path = img_info["file_name"]
          (sample, labels) = self.dataset[index]

-         return (path, sample, labels, F.get_size(sample))
+         # Get original image size (height, width) before transforms
+         orig_size = [img_info["height"], img_info["width"]]
+
+         return (path, sample, labels, orig_size)


  class CocoMosaicTraining(CocoBase):
@@ -127,9 +131,7 @@ class CocoMosaicTraining(CocoBase):
          self._mosaic_decay_epochs: Optional[int] = None
          self._mosaic_decay_start: Optional[int] = None

-     def configure_mosaic_linear_decay(
-         self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1
-     ) -> None:
+     def configure_mosaic_linear_decay(self, base_prob: float, total_epochs: int, decay_fraction: float = 0.1) -> None:
          if total_epochs <= 0:
              raise ValueError("total_epochs must be positive")
          if decay_fraction <= 0.0 or decay_fraction > 1.0:
@@ -141,11 +143,7 @@
          self._mosaic_decay_start = max(1, total_epochs - decay_epochs + 1)

      def update_mosaic_prob(self, epoch: int) -> Optional[float]:
-         if (
-             self._mosaic_base_prob is None
-             or self._mosaic_decay_epochs is None
-             or self._mosaic_decay_start is None
-         ):
+         if self._mosaic_base_prob is None or self._mosaic_decay_epochs is None or self._mosaic_decay_start is None:
              return None

          if epoch >= self._mosaic_decay_start:
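
Note: a worked example of the decay window arithmetic. The expression for decay_epochs falls in the gap between the two hunks above, so the round() below is an assumption; decay_start follows the line shown:

    total_epochs = 100
    decay_fraction = 0.1
    decay_epochs = max(1, round(decay_fraction * total_epochs))  # assumed; not shown in this diff
    decay_start = max(1, total_epochs - decay_epochs + 1)
    print(decay_epochs, decay_start)  # 10 91: mosaic probability ramps down over the final epochs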
birder/data/transforms/detection.py CHANGED
@@ -1,3 +1,4 @@
+ import math
  import random
  from collections.abc import Callable
  from typing import Any
@@ -10,6 +11,24 @@ from torchvision.transforms import v2

  from birder.data.transforms.classification import RGBType

+ MULTISCALE_STEP = 32
+ DEFAULT_MULTISCALE_MIN_SIZE = 480
+ DEFAULT_MULTISCALE_MAX_SIZE = 800
+
+
+ def build_multiscale_sizes(
+     min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE
+ ) -> tuple[int, ...]:
+     if min_size is None:
+         min_size = DEFAULT_MULTISCALE_MIN_SIZE
+
+     start = int(math.ceil(min_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+     end = int(math.floor(max_size / MULTISCALE_STEP) * MULTISCALE_STEP)
+     if end < start:
+         return (start,)
+
+     return tuple(range(start, end + 1, MULTISCALE_STEP))
+

  class ResizeWithRandomInterpolation(nn.Module):
      def __init__(
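
Note: build_multiscale_sizes rounds min_size up and max_size down to multiples of 32, then enumerates the range. A few worked values:

    from birder.data.transforms.detection import build_multiscale_sizes

    print(build_multiscale_sizes())     # (480, 512, ..., 800): matches the old hard-coded tuple
    print(build_multiscale_sizes(500))  # (512, 544, ..., 800): ceil(500 / 32) * 32 = 512
    print(build_multiscale_sizes(900))  # (928,): min above max collapses to a single rounded size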
@@ -39,6 +58,7 @@ def get_birder_augment(
      dynamic_size: bool,
      multiscale: bool,
      max_size: Optional[int],
+     multiscale_min_size: Optional[int],
      post_mosaic: bool = False,
  ) -> Callable[..., torch.Tensor]:
      if dynamic_size is True:
@@ -78,9 +98,7 @@
      # Resize
      if multiscale is True:
          transformations.append(
-             v2.RandomShortestSize(
-                 min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-             ),
+             v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
          )
      else:
          transformations.append(
@@ -132,6 +150,7 @@ def get_birder_augment(
  AugType = Literal["birder", "lsj", "multiscale", "ssd", "ssdlite", "yolo", "detr"]


+ # pylint: disable=too-many-return-statements
  def training_preset(
      size: tuple[int, int],
      aug_type: AugType,
@@ -140,6 +159,7 @@
      dynamic_size: bool = False,
      multiscale: bool = False,
      max_size: Optional[int] = None,
+     multiscale_min_size: Optional[int] = None,
      post_mosaic: bool = False,
  ) -> Callable[..., torch.Tensor]:
      mean = rgv_values["mean"]
@@ -159,7 +179,9 @@
      return v2.Compose(  # type:ignore
          [
              v2.ToImage(),
-             get_birder_augment(size, level, fill_value, dynamic_size, multiscale, max_size, post_mosaic),
+             get_birder_augment(
+                 size, level, fill_value, dynamic_size, multiscale, max_size, multiscale_min_size, post_mosaic
+             ),
              v2.ToDtype(torch.float32, scale=True),
              v2.Normalize(mean=mean, std=std),
              v2.ToPureTensor(),
@@ -190,9 +212,7 @@ def training_preset(
      return v2.Compose(  # type: ignore
          [
              v2.ToImage(),
-             v2.RandomShortestSize(
-                 min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-             ),
+             v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
              v2.RandomHorizontalFlip(0.5),
              v2.SanitizeBoundingBoxes(),
              v2.ToDtype(torch.float32, scale=True),
@@ -264,21 +284,18 @@ def training_preset(
      )

      if aug_type == "detr":
+         multiscale_sizes = build_multiscale_sizes(multiscale_min_size)
          return v2.Compose(  # type: ignore
              [
                  v2.ToImage(),
                  v2.RandomChoice(
                      [
-                         v2.RandomShortestSize(
-                             (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                         ),
+                         v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                          v2.Compose(
                              [
                                  v2.RandomShortestSize((400, 500, 600)),
                                  v2.RandomIoUCrop() if post_mosaic is False else v2.Identity(),  # RandomSizeCrop
-                                 v2.RandomShortestSize(
-                                     (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=max_size or 1333
-                                 ),
+                                 v2.RandomShortestSize(min_size=multiscale_sizes, max_size=max_size or 1333),
                              ]
                          ),
                      ]
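
Note: a hedged call sketch wiring the new knob through training_preset. The level and rgv_values parameters appear in the function body, but their exact signature positions fall in diff gaps, so the values below are illustrative assumptions:

    from birder.data.transforms.detection import training_preset

    transform = training_preset(
        size=(640, 640),
        aug_type="multiscale",
        level=3,  # assumed augmentation-strength parameter
        rgv_values={"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
        multiscale_min_size=416,  # short-edge candidates become (416, 448, ..., 800)
        max_size=1333,
    )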
birder/inference/detection.py CHANGED
@@ -5,17 +5,99 @@ from typing import Optional
  import torch
  import torch.amp
  from PIL import Image
+ from torch.nn import functional as F
  from torch.utils.data import DataLoader
  from tqdm import tqdm

  from birder.conf import settings
+ from birder.data.collators.detection import batch_images
  from birder.data.transforms.detection import InferenceTransform
+ from birder.inference.wbf import fuse_detections_wbf
+ from birder.net.base import make_divisible
+
+
+ def _normalize_image_sizes(inputs: torch.Tensor, image_sizes: Optional[list[list[int]]]) -> list[list[int]]:
+     if image_sizes is not None:
+         return image_sizes
+
+     (_, _, height, width) = inputs.shape
+     return [[height, width] for _ in range(inputs.size(0))]
+
+
+ def _hflip_inputs(inputs: torch.Tensor, image_sizes: list[list[int]]) -> torch.Tensor:
+     # Detection collator pads on the right/bottom, so flip only the valid region to keep padding aligned.
+     flipped = inputs.clone()
+     for idx, (height, width) in enumerate(image_sizes):
+         flipped[idx, :, :height, :width] = torch.flip(inputs[idx, :, :height, :width], dims=[2])
+
+     return flipped
+
+
+ def _resize_batch(
+     inputs: torch.Tensor, image_sizes: list[list[int]], scale: float, size_divisible: int
+ ) -> tuple[torch.Tensor, torch.Tensor, list[list[int]]]:
+     resized_images: list[torch.Tensor] = []
+     for idx, (height, width) in enumerate(image_sizes):
+         target_h = make_divisible(height * scale, size_divisible)
+         target_w = make_divisible(width * scale, size_divisible)
+         image = inputs[idx, :, :height, :width]
+         resized = F.interpolate(image.unsqueeze(0), size=(target_h, target_w), mode="bilinear", align_corners=False)
+         resized_images.append(resized.squeeze(0))
+
+     return batch_images(resized_images, size_divisible)
+
+
+ def _rescale_boxes(boxes: torch.Tensor, from_size: list[int], to_size: list[int]) -> torch.Tensor:
+     scale_w = to_size[1] / from_size[1]
+     scale_h = to_size[0] / from_size[0]
+     scale = boxes.new_tensor([scale_w, scale_h, scale_w, scale_h])
+     return boxes * scale
+
+
+ def _rescale_detections(
+     detections: list[dict[str, torch.Tensor]],
+     from_sizes: list[list[int]],
+     to_sizes: list[list[int]],
+ ) -> list[dict[str, torch.Tensor]]:
+     for idx, (detection, from_size, to_size) in enumerate(zip(detections, from_sizes, to_sizes)):
+         boxes = detection["boxes"]
+         if boxes.numel() == 0:
+             continue
+
+         detections[idx]["boxes"] = _rescale_boxes(boxes, from_size, to_size)
+
+     return detections
+
+
+ def _invert_hflip_boxes(boxes: torch.Tensor, image_size: list[int]) -> torch.Tensor:
+     width = boxes.new_tensor(image_size[1])
+     x1 = boxes[:, 0]
+     x2 = boxes[:, 2]
+     flipped = boxes.clone()
+     flipped[:, 0] = width - x2
+     flipped[:, 2] = width - x1
+
+     return flipped
+
+
+ def _invert_detections(
+     detections: list[dict[str, torch.Tensor]], image_sizes: list[list[int]]
+ ) -> list[dict[str, torch.Tensor]]:
+     for idx, (detection, image_size) in enumerate(zip(detections, image_sizes)):
+         boxes = detection["boxes"]
+         if boxes.numel() == 0:
+             continue
+
+         detections[idx]["boxes"] = _invert_hflip_boxes(boxes, image_size)
+
+     return detections


  def infer_image(
      net: torch.nn.Module | torch.ScriptModule,
      sample: Image.Image | str,
      transform: Callable[..., torch.Tensor],
+     tta: bool = False,
      device: Optional[torch.device] = None,
      score_threshold: Optional[float] = None,
      **kwargs: Any,
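
Note: the flip inversion mirrors and swaps the x coordinates while leaving y untouched. A quick numeric check of the formula above (a standalone reimplementation, not an import from the package):

    import torch

    boxes = torch.tensor([[10.0, 20.0, 50.0, 60.0]])  # (x1, y1, x2, y2) in a 100px-wide image
    width = 100.0
    flipped = boxes.clone()
    flipped[:, 0] = width - boxes[:, 2]  # new x1 = W - old x2
    flipped[:, 2] = width - boxes[:, 0]  # new x2 = W - old x1
    print(flipped)  # tensor([[50., 20., 90., 60.]]): x1 < x2 is preserved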
@@ -43,7 +125,7 @@
          device = torch.device("cpu")

      input_tensor = transform(image).unsqueeze(dim=0).to(device)
-     detections = infer_batch(net, input_tensor, **kwargs)
+     detections = infer_batch(net, input_tensor, tta=tta, **kwargs)
      if score_threshold is not None:
          for i, detection in enumerate(detections):
              idxs = torch.where(detection["scores"] > score_threshold)
@@ -63,16 +145,36 @@
      inputs: torch.Tensor,
      masks: Optional[torch.Tensor] = None,
      image_sizes: Optional[list[list[int]]] = None,
+     tta: bool = False,
      **kwargs: Any,
  ) -> list[dict[str, torch.Tensor]]:
-     (detections, _) = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
-     return detections  # type: ignore[no-any-return]
+     if tta is False:
+         (detections, _) = net(inputs, masks=masks, image_sizes=image_sizes, **kwargs)
+         return detections  # type: ignore[no-any-return]
+
+     normalized_sizes = _normalize_image_sizes(inputs, image_sizes)
+     detections_list: list[list[dict[str, torch.Tensor]]] = []
+
+     for scale in (0.8, 1.0, 1.2):
+         (scaled_inputs, scaled_masks, scaled_sizes) = _resize_batch(inputs, normalized_sizes, scale, size_divisible=32)
+         (detections, _) = net(scaled_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+         detections = _rescale_detections(detections, scaled_sizes, normalized_sizes)
+         detections_list.append(detections)
+
+         flipped_inputs = _hflip_inputs(scaled_inputs, scaled_sizes)
+         (flipped_detections, _) = net(flipped_inputs, masks=scaled_masks, image_sizes=scaled_sizes, **kwargs)
+         flipped_detections = _invert_detections(flipped_detections, scaled_sizes)
+         flipped_detections = _rescale_detections(flipped_detections, scaled_sizes, normalized_sizes)
+         detections_list.append(flipped_detections)
+
+     return fuse_detections_wbf(detections_list, iou_thr=0.55, conf_type="avg")


  def infer_dataloader(
      device: torch.device,
      net: torch.nn.Module | torch.ScriptModule,
      dataloader: DataLoader,
+     tta: bool = False,
      model_dtype: torch.dtype = torch.float32,
      amp: bool = False,
      amp_dtype: Optional[torch.dtype] = None,
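
Note: TTA stays opt-in at the call site. A hedged sketch (net, inputs, masks and image_sizes stand for an already-loaded birder detection model and a collated batch):

    from birder.inference.detection import infer_batch

    # Default: a single forward pass, unchanged behavior
    detections = infer_batch(net, inputs, masks=masks, image_sizes=image_sizes)

    # TTA: 3 scales (0.8, 1.0, 1.2) x {original, hflip} = 6 passes, fused with WBF (iou_thr=0.55)
    detections = infer_batch(net, inputs, masks=masks, image_sizes=image_sizes, tta=True)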
@@ -97,6 +199,8 @@
          The model to use for inference.
      dataloader
          The DataLoader containing the dataset to perform inference on.
+     tta
+         Run inference with multi-scale and horizontal flip test time augmentation and fuse results with WBF.
      model_dtype
          The base dtype to use.
      amp
@@ -142,7 +246,7 @@
              masks = masks.to(device, non_blocking=True)

          with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
-             detections = infer_batch(net, inputs, masks, image_sizes)
+             detections = infer_batch(net, inputs, masks=masks, image_sizes=image_sizes, tta=tta)

          detections = InferenceTransform.postprocess(detections, image_sizes, orig_sizes)
          if targets[0] != settings.NO_LABEL: