birder 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +18 -0
  3. birder/common/training_utils.py +123 -10
  4. birder/data/collators/detection.py +10 -3
  5. birder/data/datasets/coco.py +8 -10
  6. birder/data/transforms/detection.py +30 -13
  7. birder/inference/detection.py +108 -4
  8. birder/inference/wbf.py +226 -0
  9. birder/net/__init__.py +8 -0
  10. birder/net/detection/efficientdet.py +65 -86
  11. birder/net/detection/rt_detr_v1.py +1 -0
  12. birder/net/detection/yolo_anchors.py +205 -0
  13. birder/net/detection/yolo_v2.py +25 -24
  14. birder/net/detection/yolo_v3.py +39 -40
  15. birder/net/detection/yolo_v4.py +28 -26
  16. birder/net/detection/yolo_v4_tiny.py +24 -20
  17. birder/net/fasternet.py +1 -1
  18. birder/net/gc_vit.py +671 -0
  19. birder/net/lit_v1.py +472 -0
  20. birder/net/lit_v1_tiny.py +342 -0
  21. birder/net/lit_v2.py +436 -0
  22. birder/net/mobilenet_v4_hybrid.py +1 -1
  23. birder/net/resnet_v1.py +1 -1
  24. birder/net/resnext.py +67 -25
  25. birder/net/se_resnet_v1.py +46 -0
  26. birder/net/se_resnext.py +3 -0
  27. birder/net/simple_vit.py +2 -2
  28. birder/net/vit.py +0 -15
  29. birder/net/vovnet_v2.py +31 -1
  30. birder/scripts/benchmark.py +90 -21
  31. birder/scripts/predict.py +1 -0
  32. birder/scripts/predict_detection.py +18 -11
  33. birder/scripts/train.py +10 -34
  34. birder/scripts/train_barlow_twins.py +10 -34
  35. birder/scripts/train_byol.py +10 -34
  36. birder/scripts/train_capi.py +10 -35
  37. birder/scripts/train_data2vec.py +9 -34
  38. birder/scripts/train_data2vec2.py +9 -34
  39. birder/scripts/train_detection.py +48 -40
  40. birder/scripts/train_dino_v1.py +10 -34
  41. birder/scripts/train_dino_v2.py +9 -34
  42. birder/scripts/train_dino_v2_dist.py +9 -34
  43. birder/scripts/train_franca.py +9 -34
  44. birder/scripts/train_i_jepa.py +9 -34
  45. birder/scripts/train_ibot.py +9 -34
  46. birder/scripts/train_kd.py +156 -64
  47. birder/scripts/train_mim.py +10 -34
  48. birder/scripts/train_mmcr.py +10 -34
  49. birder/scripts/train_rotnet.py +10 -34
  50. birder/scripts/train_simclr.py +10 -34
  51. birder/scripts/train_vicreg.py +10 -34
  52. birder/tools/auto_anchors.py +20 -1
  53. birder/tools/pack.py +172 -103
  54. birder/tools/show_det_iterator.py +10 -1
  55. birder/version.py +1 -1
  56. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
  57. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
  58. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
  59. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
  60. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
  61. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
@@ -41,7 +41,6 @@ from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_mim_network_name
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
  from birder.data.datasets.directory import make_image_dataset
@@ -79,44 +78,16 @@ def train(args: argparse.Namespace) -> None:
  #
  # Initialize
  #
- training_utils.init_distributed_mode(args)
- logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
- training_utils.log_git_info()
+ (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

  if args.size is None:
  args.size = registry.get_default_size(args.network)

  logger.info(f"Using size={args.size}")

- if args.cpu is True:
- device = torch.device("cpu")
- device_id = 0
- else:
- device = torch.device("cuda")
- device_id = torch.cuda.current_device()
-
- if args.use_deterministic_algorithms is True:
- torch.backends.cudnn.benchmark = False
- torch.use_deterministic_algorithms(True)
- else:
- torch.backends.cudnn.benchmark = True
-
- if args.seed is not None:
- set_random_seeds(args.seed)
-
- if args.non_interactive is True or training_utils.is_local_primary(args) is False:
- disable_tqdm = True
- elif sys.stderr.isatty() is False:
- disable_tqdm = True
- else:
- disable_tqdm = False
-
- # Enable or disable the autograd anomaly detection
- torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
  batch_size: int = args.batch_size
  grad_accum_steps: int = args.grad_accum_steps
- logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+ logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
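
Across the training scripts in this release, the per-script initialization boilerplate (device selection, cuDNN/determinism flags, seeding, tqdm gating, anomaly detection) is replaced by a single call to training_utils.init_training. The helper itself lands in birder/common/training_utils.py (+123 lines in this release) but its body is not part of the hunks shown here; the sketch below is only a rough reconstruction from the lines removed above, with the seeding call and the distributed/primary-rank handling stubbed out as assumptions.

    import argparse
    import logging
    import sys

    import torch


    def init_training(args: argparse.Namespace, logger: logging.Logger) -> tuple[torch.device, int, bool]:
        # Hypothetical reconstruction based on the removed boilerplate; the real helper
        # presumably also calls init_distributed_mode(), logs the birder version and git info.
        logger.info(f"Starting training, pytorch version: {torch.__version__}")

        if args.cpu is True:
            device = torch.device("cpu")
            device_id = 0
        else:
            device = torch.device("cuda")
            device_id = torch.cuda.current_device()

        if args.use_deterministic_algorithms is True:
            torch.backends.cudnn.benchmark = False
            torch.use_deterministic_algorithms(True)
        else:
            torch.backends.cudnn.benchmark = True

        if args.seed is not None:
            torch.manual_seed(args.seed)  # stand-in for birder.common.lib.set_random_seeds

        # The real code also disables tqdm on non-primary distributed ranks
        disable_tqdm = args.non_interactive is True or not sys.stderr.isatty()

        # Enable or disable the autograd anomaly detection
        torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

        return (device, device_id, disable_tqdm)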
@@ -281,28 +252,32 @@ def train(args: argparse.Namespace) -> None:

  optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
+ logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")

  #
  # Loss criteria, optimizer, learning rate scheduler and training parameter groups
  #

+ # Learning rate scaling
+ lr = training_utils.scale_lr(args)
+ clustering_lr = lr / 2
+
  # Training parameter groups
  custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
  parameters = training_utils.optimizer_parameter_groups(
  student,
  args.wd,
+ base_lr=lr,
  norm_weight_decay=args.norm_wd,
  custom_keys_weight_decay=custom_keys_weight_decay,
+ custom_layer_weight_decay=args.custom_layer_wd,
  layer_decay=args.layer_decay,
  layer_decay_min_scale=args.layer_decay_min_scale,
  layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
  bias_lr=args.bias_lr,
+ custom_layer_lr_scale=args.custom_layer_lr_scale,
  )

- # Learning rate scaling
- lr = training_utils.scale_lr(args)
- clustering_lr = lr / 2
-
  if args.lr_scheduler_update == "epoch":
  step_update = False
  scheduler_steps_per_epoch = 1
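
Every training script now computes the scaled learning rate (training_utils.scale_lr) before building the optimizer parameter groups and passes it in as base_lr, alongside the new custom_layer_wd and custom_layer_lr_scale options. The reordering makes sense if per-group learning rates are derived from the base LR, for example lr = base_lr * scale for layer decay or custom layer scales. The fragment below only illustrates that idea with made-up names; it is not birder's optimizer_parameter_groups implementation.

    def group_learning_rates(base_lr: float, layer_scales: dict[str, float]) -> dict[str, float]:
        # Each parameter group ends up with lr = base_lr * its scale,
        # which is why base_lr has to be known before the groups are built.
        return {layer: base_lr * scale for layer, scale in layer_scales.items()}


    print(group_learning_rates(0.001, {"stem": 0.25, "stages.3": 1.0, "head": 2.0}))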
@@ -31,7 +31,6 @@ from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_mim_network_name
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.common.masking import BlockMasking
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
@@ -69,9 +68,7 @@ def train(args: argparse.Namespace) -> None:
  #
  # Initialize
  #
- training_utils.init_distributed_mode(args)
- logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
- training_utils.log_git_info()
+ (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

  if args.size is None:
  # Prefer mim size over encoder default size
@@ -79,35 +76,9 @@ def train(args: argparse.Namespace) -> None:

  logger.info(f"Using size={args.size}")

- if args.cpu is True:
- device = torch.device("cpu")
- device_id = 0
- else:
- device = torch.device("cuda")
- device_id = torch.cuda.current_device()
-
- if args.use_deterministic_algorithms is True:
- torch.backends.cudnn.benchmark = False
- torch.use_deterministic_algorithms(True)
- else:
- torch.backends.cudnn.benchmark = True
-
- if args.seed is not None:
- set_random_seeds(args.seed)
-
- if args.non_interactive is True or training_utils.is_local_primary(args) is False:
- disable_tqdm = True
- elif sys.stderr.isatty() is False:
- disable_tqdm = True
- else:
- disable_tqdm = False
-
- # Enable or disable the autograd anomaly detection
- torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
  batch_size: int = args.batch_size
  grad_accum_steps: int = args.grad_accum_steps
- logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+ logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
@@ -248,27 +219,31 @@ def train(args: argparse.Namespace) -> None:

  optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
+ logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")

  #
  # Loss criteria, optimizer, learning rate scheduler and training parameter groups
  #

+ # Learning rate scaling
+ lr = training_utils.scale_lr(args)
+
  # Training parameter groups
  custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
  parameters = training_utils.optimizer_parameter_groups(
  net,
  args.wd,
+ base_lr=lr,
  norm_weight_decay=args.norm_wd,
  custom_keys_weight_decay=custom_keys_weight_decay,
+ custom_layer_weight_decay=args.custom_layer_wd,
  layer_decay=args.layer_decay,
  layer_decay_min_scale=args.layer_decay_min_scale,
  layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
  bias_lr=args.bias_lr,
+ custom_layer_lr_scale=args.custom_layer_lr_scale,
  )

- # Learning rate scaling
- lr = training_utils.scale_lr(args)
-
  if args.lr_scheduler_update == "epoch":
  step_update = False
  scheduler_steps_per_epoch = 1
@@ -34,7 +34,6 @@ from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_mim_network_name
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.common.masking import InverseRollBlockMasking
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
@@ -75,9 +74,7 @@ def train(args: argparse.Namespace) -> None:
  #
  # Initialize
  #
- training_utils.init_distributed_mode(args)
- logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
- training_utils.log_git_info()
+ (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

  if args.size is None:
  # Prefer mim size over encoder default size
@@ -85,35 +82,9 @@ def train(args: argparse.Namespace) -> None:

  logger.info(f"Using size={args.size}")

- if args.cpu is True:
- device = torch.device("cpu")
- device_id = 0
- else:
- device = torch.device("cuda")
- device_id = torch.cuda.current_device()
-
- if args.use_deterministic_algorithms is True:
- torch.backends.cudnn.benchmark = False
- torch.use_deterministic_algorithms(True)
- else:
- torch.backends.cudnn.benchmark = True
-
- if args.seed is not None:
- set_random_seeds(args.seed)
-
- if args.non_interactive is True or training_utils.is_local_primary(args) is False:
- disable_tqdm = True
- elif sys.stderr.isatty() is False:
- disable_tqdm = True
- else:
- disable_tqdm = False
-
- # Enable or disable the autograd anomaly detection
- torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
  batch_size: int = args.batch_size
  grad_accum_steps: int = args.grad_accum_steps
- logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+ logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
@@ -257,27 +228,31 @@ def train(args: argparse.Namespace) -> None:

  optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
+ logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")

  #
  # Loss criteria, optimizer, learning rate scheduler and training parameter groups
  #

+ # Learning rate scaling
+ lr = training_utils.scale_lr(args)
+
  # Training parameter groups
  custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
  parameters = training_utils.optimizer_parameter_groups(
  net,
  args.wd,
+ base_lr=lr,
  norm_weight_decay=args.norm_wd,
  custom_keys_weight_decay=custom_keys_weight_decay,
+ custom_layer_weight_decay=args.custom_layer_wd,
  layer_decay=args.layer_decay,
  layer_decay_min_scale=args.layer_decay_min_scale,
  layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
  bias_lr=args.bias_lr,
+ custom_layer_lr_scale=args.custom_layer_lr_scale,
  )

- # Learning rate scaling
- lr = training_utils.scale_lr(args)
-
  if args.lr_scheduler_update == "epoch":
  step_update = False
  scheduler_steps_per_epoch = 1
@@ -53,10 +53,6 @@ def train(args: argparse.Namespace) -> None:
  #
  # Initialize
  #
- training_utils.init_distributed_mode(args)
- logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
- training_utils.log_git_info()
-
  transform_dynamic_size = (
  args.multiscale is True
  or args.dynamic_size is True
@@ -66,6 +62,10 @@ def train(args: argparse.Namespace) -> None:
  )
  model_dynamic_size = transform_dynamic_size or args.batch_multiscale is True

+ (device, device_id, disable_tqdm) = training_utils.init_training(
+ args, logger, cudnn_dynamic_size=transform_dynamic_size
+ )
+
  if args.size is None:
  args.size = registry.get_default_size(args.network)

@@ -76,36 +76,6 @@ def train(args: argparse.Namespace) -> None:
  else:
  logger.info(f"Running with dynamic size, with base size={args.size}")

- if args.cpu is True:
- device = torch.device("cpu")
- device_id = 0
- else:
- device = torch.device("cuda")
- device_id = torch.cuda.current_device()
-
- if args.use_deterministic_algorithms is True:
- torch.backends.cudnn.benchmark = False
- torch.use_deterministic_algorithms(True)
- elif transform_dynamic_size is True:
- # Disable cuDNN for dynamic sizes to avoid per-size algorithm selection overhead
- torch.backends.cudnn.enabled = False
- else:
- torch.backends.cudnn.enabled = True
- torch.backends.cudnn.benchmark = True
-
- if args.seed is not None:
- lib.set_random_seeds(args.seed)
-
- if args.non_interactive is True or training_utils.is_local_primary(args) is False:
- disable_tqdm = True
- elif sys.stderr.isatty() is False:
- disable_tqdm = True
- else:
- disable_tqdm = False
-
- # Enable or disable the autograd anomaly detection
- torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
  #
  # Data
  #
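
In the detection script the removed block has an extra branch: when the transforms produce dynamic input sizes, cuDNN is disabled entirely to avoid per-size algorithm selection overhead. The new init_training call forwards that condition as cudnn_dynamic_size=transform_dynamic_size. Below is a sketch of how the flag is presumably honored inside the helper, mirroring the branch removed above; the actual implementation is not in this diff.

    import torch


    def configure_cudnn(use_deterministic: bool, dynamic_size: bool) -> None:
        # Assumed behaviour of init_training's cudnn_dynamic_size flag,
        # reconstructed from the removed branch above.
        if use_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.use_deterministic_algorithms(True)
        elif dynamic_size:
            # Disable cuDNN for dynamic sizes to avoid per-size algorithm selection overhead
            torch.backends.cudnn.enabled = False
        else:
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True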
@@ -113,7 +83,14 @@ def train(args: argparse.Namespace) -> None:
  logger.debug(f"Using RGB stats: {rgb_stats}")

  transforms = training_preset(
- args.size, args.aug_type, args.aug_level, rgb_stats, args.dynamic_size, args.multiscale, args.max_size
+ args.size,
+ args.aug_type,
+ args.aug_level,
+ rgb_stats,
+ args.dynamic_size,
+ args.multiscale,
+ args.max_size,
+ args.multiscale_min_size,
  )
  mosaic_dataset = None
  if args.mosaic_prob > 0.0:
@@ -125,6 +102,7 @@ def train(args: argparse.Namespace) -> None:
  args.dynamic_size,
  args.multiscale,
  args.max_size,
+ args.multiscale_min_size,
  post_mosaic=True,
  )
  if args.dynamic_size is True or args.multiscale is True:
@@ -194,13 +172,13 @@ def train(args: argparse.Namespace) -> None:
  batch_size: int = args.batch_size
  grad_accum_steps: int = args.grad_accum_steps
  model_ema_steps: int = args.model_ema_steps
- logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+ logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

  # Data loaders and samplers
  (train_sampler, validation_sampler) = training_utils.get_samplers(args, training_dataset, validation_dataset)

  if args.batch_multiscale is True:
- train_collate_fn: Any = BatchRandomResizeCollator(0, args.size)
+ train_collate_fn: Any = BatchRandomResizeCollator(0, args.size, multiscale_min_size=args.multiscale_min_size)
  else:
  train_collate_fn = training_collate_fn

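The batch-multiscale collator now takes the new multiscale_min_size argument, which suggests the per-batch target size is drawn from a range bounded below by that value. The fragment below only illustrates the general idea of batch-level multiscale resizing; it is not birder's BatchRandomResizeCollator, and the detection targets (boxes) would also need rescaling, which is omitted here.

    import random

    import torch
    import torch.nn.functional as F


    def batch_random_resize(images: torch.Tensor, base_size: int, min_size: int, stride: int = 32) -> torch.Tensor:
        # Pick one target size per batch from [min_size, base_size] on a stride grid,
        # then resize every image in the batch to it (box targets omitted).
        target = random.choice(range(min_size, base_size + 1, stride))
        return F.interpolate(images, size=(target, target), mode="bilinear", align_corners=False)


    batch = torch.rand(4, 3, 640, 640)
    print(batch_random_resize(batch, base_size=640, min_size=384).shape)
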
@@ -236,6 +214,8 @@ def train(args: argparse.Namespace) -> None:
  else:
  args.stop_epoch += 1

+ logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
  #
  # Initialize network
  #
@@ -354,23 +334,26 @@ def train(args: argparse.Namespace) -> None:
  # Loss criteria, optimizer, learning rate scheduler and training parameter groups
  #

+ # Learning rate scaling
+ lr = training_utils.scale_lr(args)
+
  # Training parameter groups
  custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
  parameters = training_utils.optimizer_parameter_groups(
  net,
  args.wd,
+ base_lr=lr,
  norm_weight_decay=args.norm_wd,
  custom_keys_weight_decay=custom_keys_weight_decay,
+ custom_layer_weight_decay=args.custom_layer_wd,
  layer_decay=args.layer_decay,
  layer_decay_min_scale=args.layer_decay_min_scale,
  layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
  bias_lr=args.bias_lr,
  backbone_lr=args.backbone_lr,
+ custom_layer_lr_scale=args.custom_layer_lr_scale,
  )

- # Learning rate scaling
- lr = training_utils.scale_lr(args)
-
  if args.lr_scheduler_update == "epoch":
  step_update = False
  scheduler_steps_per_epoch = 1
@@ -857,6 +840,31 @@ def get_args_parser() -> argparse.ArgumentParser:
  " --fast-matmul \\\n"
  " --compile-backbone \\\n"
  " --compile-opt\n"
+ "\n"
+ "YOLO v4 with custom anchors training example (COCO):\n"
+ "python train_detection.py \\\n"
+ " --network yolo_v4 \\\n"
+ " --model-config anchors=data/anchors.json \\\n"
+ " --tag coco \\\n"
+ " --backbone csp_darknet_53 \\\n"
+ " --backbone-model-config drop_block=0.1 \\\n"
+ " --lr 0.001 \\\n"
+ " --lr-scheduler multistep \\\n"
+ " --lr-steps 300 350 \\\n"
+ " --lr-step-gamma 0.1 \\\n"
+ " --batch-size 32 \\\n"
+ " --warmup-epochs 5 \\\n"
+ " --epochs 400 \\\n"
+ " --wd 0.0005 \\\n"
+ " --aug-level 5 \\\n"
+ " --mosaic-prob 0.5 --mosaic-stop-epoch 360 \\\n"
+ " --batch-multiscale \\\n"
+ " --amp --amp-dtype float16 \\\n"
+ " --data-path ~/Datasets/cocodataset/train2017 \\\n"
+ " --val-path ~/Datasets/cocodataset/val2017 \\\n"
+ " --coco-json-path ~/Datasets/cocodataset/annotations/instances_train2017.json \\\n"
+ " --coco-val-json-path ~/Datasets/cocodataset/annotations/instances_val2017.json \\\n"
+ " --class-file public_datasets_metadata/coco-classes.txt\n"
  ),
  formatter_class=cli.ArgumentHelpFormatter,
  )
@@ -39,7 +39,6 @@ from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_mim_network_name
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
  from birder.data.datasets.directory import make_image_dataset
@@ -101,41 +100,13 @@ def train(args: argparse.Namespace) -> None:
  #
  # Initialize
  #
- training_utils.init_distributed_mode(args)
- logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
- training_utils.log_git_info()
+ (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

  if args.size is None:
  args.size = registry.get_default_size(args.network)

  logger.info(f"Using size={args.size}")

- if args.cpu is True:
- device = torch.device("cpu")
- device_id = 0
- else:
- device = torch.device("cuda")
- device_id = torch.cuda.current_device()
-
- if args.use_deterministic_algorithms is True:
- torch.backends.cudnn.benchmark = False
- torch.use_deterministic_algorithms(True)
- else:
- torch.backends.cudnn.benchmark = True
-
- if args.seed is not None:
- set_random_seeds(args.seed)
-
- if args.non_interactive is True or training_utils.is_local_primary(args) is False:
- disable_tqdm = True
- elif sys.stderr.isatty() is False:
- disable_tqdm = True
- else:
- disable_tqdm = False
-
- # Enable or disable the autograd anomaly detection
- torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
  #
  # Data
  #
@@ -187,7 +158,7 @@ def train(args: argparse.Namespace) -> None:

  batch_size: int = args.batch_size
  grad_accum_steps: int = args.grad_accum_steps
- logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+ logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

  # Data loaders and samplers
  if args.distributed is True:
@@ -228,6 +199,8 @@ def train(args: argparse.Namespace) -> None:
  else:
  args.stop_epoch += 1

+ logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
  #
  # Initialize networks
  #
@@ -339,22 +312,25 @@ def train(args: argparse.Namespace) -> None:
  # Loss criteria, optimizer, learning rate scheduler and training parameter groups
  #

+ # Learning rate scaling
+ lr = training_utils.scale_lr(args)
+
  # Training parameter groups
  custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
  parameters = training_utils.optimizer_parameter_groups(
  student,
  args.wd,
+ base_lr=lr,
  norm_weight_decay=args.norm_wd,
  custom_keys_weight_decay=custom_keys_weight_decay,
+ custom_layer_weight_decay=args.custom_layer_wd,
  layer_decay=args.layer_decay,
  layer_decay_min_scale=args.layer_decay_min_scale,
  layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
  bias_lr=args.bias_lr,
+ custom_layer_lr_scale=args.custom_layer_lr_scale,
  )

- # Learning rate scaling
- lr = training_utils.scale_lr(args)
-
  if args.lr_scheduler_update == "epoch":
  step_update = False
  scheduler_steps_per_epoch = 1
@@ -36,7 +36,6 @@ from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_mim_network_name
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.common.masking import BlockMasking
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
@@ -178,44 +177,16 @@ def train(args: argparse.Namespace) -> None:
  #
  # Initialize
  #
- training_utils.init_distributed_mode(args)
- logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
- training_utils.log_git_info()
+ (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

  if args.size is None:
  args.size = registry.get_default_size(args.network)

  logger.info(f"Using size={args.size}")

- if args.cpu is True:
- device = torch.device("cpu")
- device_id = 0
- else:
- device = torch.device("cuda")
- device_id = torch.cuda.current_device()
-
- if args.use_deterministic_algorithms is True:
- torch.backends.cudnn.benchmark = False
- torch.use_deterministic_algorithms(True)
- else:
- torch.backends.cudnn.benchmark = True
-
- if args.seed is not None:
- set_random_seeds(args.seed)
-
- if args.non_interactive is True or training_utils.is_local_primary(args) is False:
- disable_tqdm = True
- elif sys.stderr.isatty() is False:
- disable_tqdm = True
- else:
- disable_tqdm = False
-
- # Enable or disable the autograd anomaly detection
- torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
  batch_size: int = args.batch_size
  grad_accum_steps: int = args.grad_accum_steps
- logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+ logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
@@ -420,27 +391,31 @@ def train(args: argparse.Namespace) -> None:

  optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
+ logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")

  #
  # Optimizer, learning rate scheduler and training parameter groups
  #

+ # Learning rate scaling
+ lr = training_utils.scale_lr(args)
+
  # Training parameter groups
  custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
  parameters = training_utils.optimizer_parameter_groups(
  net,
  args.wd,
+ base_lr=lr,
  norm_weight_decay=args.norm_wd,
  custom_keys_weight_decay=custom_keys_weight_decay,
+ custom_layer_weight_decay=args.custom_layer_wd,
  layer_decay=args.layer_decay,
  layer_decay_min_scale=args.layer_decay_min_scale,
  layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
  bias_lr=args.bias_lr,
+ custom_layer_lr_scale=args.custom_layer_lr_scale,
  )

- # Learning rate scaling
- lr = training_utils.scale_lr(args)
-
  if args.lr_scheduler_update == "epoch":
  step_update = False
  scheduler_steps_per_epoch = 1