birder-0.2.2-py3-none-any.whl → birder-0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +18 -0
  3. birder/common/training_utils.py +123 -10
  4. birder/data/collators/detection.py +10 -3
  5. birder/data/datasets/coco.py +8 -10
  6. birder/data/transforms/detection.py +30 -13
  7. birder/inference/detection.py +108 -4
  8. birder/inference/wbf.py +226 -0
  9. birder/net/__init__.py +8 -0
  10. birder/net/detection/efficientdet.py +65 -86
  11. birder/net/detection/rt_detr_v1.py +1 -0
  12. birder/net/detection/yolo_anchors.py +205 -0
  13. birder/net/detection/yolo_v2.py +25 -24
  14. birder/net/detection/yolo_v3.py +39 -40
  15. birder/net/detection/yolo_v4.py +28 -26
  16. birder/net/detection/yolo_v4_tiny.py +24 -20
  17. birder/net/fasternet.py +1 -1
  18. birder/net/gc_vit.py +671 -0
  19. birder/net/lit_v1.py +472 -0
  20. birder/net/lit_v1_tiny.py +342 -0
  21. birder/net/lit_v2.py +436 -0
  22. birder/net/mobilenet_v4_hybrid.py +1 -1
  23. birder/net/resnet_v1.py +1 -1
  24. birder/net/resnext.py +67 -25
  25. birder/net/se_resnet_v1.py +46 -0
  26. birder/net/se_resnext.py +3 -0
  27. birder/net/simple_vit.py +2 -2
  28. birder/net/vit.py +0 -15
  29. birder/net/vovnet_v2.py +31 -1
  30. birder/scripts/benchmark.py +90 -21
  31. birder/scripts/predict.py +1 -0
  32. birder/scripts/predict_detection.py +18 -11
  33. birder/scripts/train.py +10 -34
  34. birder/scripts/train_barlow_twins.py +10 -34
  35. birder/scripts/train_byol.py +10 -34
  36. birder/scripts/train_capi.py +10 -35
  37. birder/scripts/train_data2vec.py +9 -34
  38. birder/scripts/train_data2vec2.py +9 -34
  39. birder/scripts/train_detection.py +48 -40
  40. birder/scripts/train_dino_v1.py +10 -34
  41. birder/scripts/train_dino_v2.py +9 -34
  42. birder/scripts/train_dino_v2_dist.py +9 -34
  43. birder/scripts/train_franca.py +9 -34
  44. birder/scripts/train_i_jepa.py +9 -34
  45. birder/scripts/train_ibot.py +9 -34
  46. birder/scripts/train_kd.py +156 -64
  47. birder/scripts/train_mim.py +10 -34
  48. birder/scripts/train_mmcr.py +10 -34
  49. birder/scripts/train_rotnet.py +10 -34
  50. birder/scripts/train_simclr.py +10 -34
  51. birder/scripts/train_vicreg.py +10 -34
  52. birder/tools/auto_anchors.py +20 -1
  53. birder/tools/pack.py +172 -103
  54. birder/tools/show_det_iterator.py +10 -1
  55. birder/version.py +1 -1
  56. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
  57. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
  58. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
  59. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
  60. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
  61. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ Supports:
  * Logits matching (Soft distillation), https://arxiv.org/abs/1503.02531
  * Hard-label distillation, https://arxiv.org/pdf/2012.12877
  * Distillation token, https://arxiv.org/pdf/2012.12877
+ * Embedding matching (L2-normalized MSE)
  """

  import argparse
@@ -16,6 +17,7 @@ import typing
  from pathlib import Path
  from typing import Any
  from typing import Literal
+ from typing import Optional

  import matplotlib.pyplot as plt
  import numpy as np
@@ -39,7 +41,6 @@ from birder.common import training_cli
  from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
  from birder.data.datasets.directory import HierarchicalImageFolder
@@ -55,7 +56,18 @@ from birder.net.base import get_signature

  logger = logging.getLogger(__name__)

- DistType = Literal["soft", "hard", "deit"]
+ DistType = Literal["soft", "hard", "deit", "embedding"]
+
+
+ class EmbeddingDistillWrapper(torch.nn.Module):
+     def __init__(self, model: torch.nn.Module) -> None:
+         super().__init__()
+         self.model = model
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         embedding = self.model.embedding(x)
+         outputs = self.model.classify(embedding)
+         return (outputs, embedding)


  # pylint: disable=too-many-locals,too-many-branches,too-many-statements
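The wrapper added in this hunk returns both the classifier logits and the backbone embedding from a single forward pass, which the new "embedding" distillation branch further down relies on. Below is a minimal, self-contained sketch of how it behaves; the toy backbone is purely illustrative and only mimics the embedding()/classify() split that the diff assumes on birder classification networks.

import torch


class ToyBackbone(torch.nn.Module):
    # Illustrative stand-in for a birder classification network: embedding()
    # produces pooled features and classify() maps them to logits.
    def __init__(self, embedding_size: int = 32, num_classes: int = 10) -> None:
        super().__init__()
        self.embedding_size = embedding_size
        self.features = torch.nn.Linear(3 * 8 * 8, embedding_size)
        self.classifier = torch.nn.Linear(embedding_size, num_classes)

    def embedding(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(torch.flatten(x, start_dim=1))

    def classify(self, embedding: torch.Tensor) -> torch.Tensor:
        return self.classifier(embedding)


class EmbeddingDistillWrapper(torch.nn.Module):
    # Same shape as the wrapper added in this release: one forward pass yields
    # both the logits and the embedding they were computed from.
    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        embedding = self.model.embedding(x)
        outputs = self.model.classify(embedding)
        return (outputs, embedding)


student = EmbeddingDistillWrapper(ToyBackbone())
(logits, emb) = student(torch.randn(4, 3, 8, 8))
print(logits.shape, emb.shape)  # torch.Size([4, 10]) torch.Size([4, 32])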
@@ -63,41 +75,11 @@ def train(args: argparse.Namespace) -> None:
      #
      # Initialize
      #
-     training_utils.init_distributed_mode(args)
-     logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-     training_utils.log_git_info()
+     (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

      if args.type != "soft":
          args.temperature = 1.0

-     logger.info(f"Using size={args.size}")
-
-     if args.cpu is True:
-         device = torch.device("cpu")
-         device_id = 0
-     else:
-         device = torch.device("cuda")
-         device_id = torch.cuda.current_device()
-
-     if args.use_deterministic_algorithms is True:
-         torch.backends.cudnn.benchmark = False
-         torch.use_deterministic_algorithms(True)
-     else:
-         torch.backends.cudnn.benchmark = True
-
-     if args.seed is not None:
-         set_random_seeds(args.seed)
-
-     if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-         disable_tqdm = True
-     elif sys.stderr.isatty() is False:
-         disable_tqdm = True
-     else:
-         disable_tqdm = False
-
-     # Enable or disable the autograd anomaly detection
-     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
      # Using the teacher rgb values for the student
      (teacher, (class_to_idx, signature, rgb_stats, *_)) = fs_ops.load_model(
          device,
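The same consolidation appears across every training script in this release: the per-script device, determinism, seeding, and tqdm setup shown in the removed lines above is replaced by a single training_utils.init_training(args, logger) call. Its implementation is not part of this diff; the sketch below is a hedged reconstruction from the deleted block (the local-primary-rank gating, distributed init, and git-info logging are omitted, and the real helper's signature and return contract may differ).

import random
import sys

import numpy as np
import torch


def init_training(args, logger):
    # Hedged reconstruction of what the new helper appears to consolidate,
    # based only on the boilerplate deleted from each script in this release.
    if args.cpu is True:
        device = torch.device("cpu")
        device_id = 0
    else:
        device = torch.device("cuda")
        device_id = torch.cuda.current_device()

    if args.use_deterministic_algorithms is True:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    if args.seed is not None:
        # Stand-in for birder.common.lib.set_random_seeds
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # The removed code also gated on training_utils.is_local_primary(args)
    disable_tqdm = args.non_interactive is True or sys.stderr.isatty() is False

    # Enable or disable the autograd anomaly detection
    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
    logger.info(f"Using device {device}")
    return (device, device_id, disable_tqdm)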
@@ -112,7 +94,8 @@ def train(args: argparse.Namespace) -> None:
      )
      if args.size is None:
          args.size = lib.get_size_from_signature(signature)
-         logger.debug(f"Using size={args.size}")
+
+     logger.info(f"Using size={args.size}")

      #
      # Data
@@ -188,7 +171,7 @@ def train(args: argparse.Namespace) -> None:
      batch_size: int = args.batch_size
      grad_accum_steps: int = args.grad_accum_steps
      model_ema_steps: int = args.model_ema_steps
-     logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+     logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

      # Set data iterators
      if args.mixup_alpha is not None or args.cutmix is True:
@@ -258,6 +241,8 @@ def train(args: argparse.Namespace) -> None:
      else:
          args.stop_epoch += 1

+     logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
      #
      # Initialize networks
      #
@@ -298,33 +283,61 @@ def train(args: argparse.Namespace) -> None:
      if args.fast_matmul is True or args.amp is True:
          torch.set_float32_matmul_precision("high")

-     # Compile networks
-     if args.compile is True:
-         teacher = torch.compile(teacher)
-         student = torch.compile(student)
-     elif args.compile_teacher is True:
-         teacher = torch.compile(teacher)
+     distillation_type: DistType = args.type
+     embedding_projection: Optional[torch.nn.Module] = None
+     if distillation_type == "embedding":
+         if student.embedding_size == teacher.embedding_size:
+             embedding_projection = torch.nn.Identity()
+         else:
+             logger.info(
+                 f"Creating embedding projection layer from {student.embedding_size} to {teacher.embedding_size}"
+             )
+             embedding_projection = torch.nn.Linear(student.embedding_size, teacher.embedding_size)
+
+         embedding_projection.to(device, dtype=model_dtype)
+         if training_states.extra_states is not None:
+             projection_state = training_states.extra_states.get("embedding_projection")
+             if projection_state is not None:
+                 embedding_projection.load_state_dict(projection_state)

      #
      # Loss criteria, optimizer, learning rate scheduler and training parameter groups
      #

+     # Learning rate scaling
+     lr = training_utils.scale_lr(args)
+
      # Training parameter groups and loss criteria
      custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
      parameters = training_utils.optimizer_parameter_groups(
          student,
          args.wd,
+         base_lr=lr,
          norm_weight_decay=args.norm_wd,
          custom_keys_weight_decay=custom_keys_weight_decay,
+         custom_layer_weight_decay=args.custom_layer_wd,
          layer_decay=args.layer_decay,
          layer_decay_min_scale=args.layer_decay_min_scale,
          layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
          bias_lr=args.bias_lr,
+         custom_layer_lr_scale=args.custom_layer_lr_scale,
      )
+     if embedding_projection is not None:
+         projection_parameters = training_utils.optimizer_parameter_groups(
+             embedding_projection,
+             args.wd,
+             base_lr=lr,
+             norm_weight_decay=args.norm_wd,
+             custom_keys_weight_decay=custom_keys_weight_decay,
+             custom_layer_weight_decay=args.custom_layer_wd,
+             bias_lr=args.bias_lr,
+             custom_layer_lr_scale=args.custom_layer_lr_scale,
+         )
+         parameters.extend(projection_parameters)
+
      criterion = torch.nn.CrossEntropyLoss(label_smoothing=args.smoothing_alpha)

      # Distillation
-     distillation_type: DistType = args.type
      if distillation_type == "soft":
          distillation_criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)
      elif distillation_type == "hard":
@@ -332,11 +345,11 @@ def train(args: argparse.Namespace) -> None:
      elif distillation_type == "deit":
          distillation_criterion = torch.nn.CrossEntropyLoss()
          student.set_distillation_output()
+     elif distillation_type == "embedding":
+         distillation_criterion = torch.nn.MSELoss()
      else:
          raise ValueError(f"Unknown KD type: {args.type}")

-     # Learning rate scaling
-     lr = training_utils.scale_lr(args)
      if args.lr_scheduler_update == "epoch":
          step_update = False
          scheduler_steps_per_epoch = 1
@@ -398,12 +411,50 @@ def train(args: argparse.Namespace) -> None:
          ema_warmup_steps = 0

      logger.debug(f"EMA warmup steps = {ema_warmup_steps}")
+     train_student = student
+     if distillation_type == "embedding":
+         train_student = EmbeddingDistillWrapper(student)
+
+     # Compile networks
+     if args.compile is True:
+         train_student = torch.compile(train_student)
+         if distillation_type == "embedding":
+             teacher.embedding = torch.compile(teacher.embedding)
+             embedding_projection = torch.compile(embedding_projection)
+             student = torch.compile(student)  # For validation
+         else:
+             teacher = torch.compile(teacher)
+             student = train_student
+
+     elif args.compile_teacher is True:
+         if distillation_type == "embedding":
+             teacher.embedding = torch.compile(teacher.embedding)
+         else:
+             teacher = torch.compile(teacher)
+
      net_without_ddp = student
      if args.distributed is True:
-         student = torch.nn.parallel.DistributedDataParallel(
-             student, device_ids=[args.local_rank], find_unused_parameters=args.find_unused_parameters
+         train_student = torch.nn.parallel.DistributedDataParallel(
+             train_student, device_ids=[args.local_rank], find_unused_parameters=args.find_unused_parameters
          )
-         net_without_ddp = student.module
+         if distillation_type != "embedding":
+             net_without_ddp = train_student.module
+
+     embedding_projection_to_save = None
+     if embedding_projection is not None:
+         if args.distributed is True and any(p.requires_grad for p in embedding_projection.parameters()):
+             embedding_projection = torch.nn.parallel.DistributedDataParallel(
+                 embedding_projection,
+                 device_ids=[args.local_rank],
+                 find_unused_parameters=args.find_unused_parameters,
+             )
+             embedding_projection_to_save = embedding_projection.module
+         else:
+             embedding_projection_to_save = embedding_projection
+
+         # Unwrap compiled module for saving
+         if hasattr(embedding_projection_to_save, "_orig_mod"):
+             embedding_projection_to_save = embedding_projection_to_save._orig_mod  # pylint: disable=protected-access

      if args.model_ema is True:
          model_base = net_without_ddp  # Original model without DDP wrapper, will be saved as training state
@@ -499,7 +550,10 @@ def train(args: argparse.Namespace) -> None:
      logger.info(f"Starting training with learning rate of {last_lr}")
      for epoch in range(begin_epoch, args.stop_epoch):
          tic = time.time()
-         student.train()
+         train_student.train()
+         if embedding_projection is not None:
+             embedding_projection.train()
+
          running_loss = training_utils.SmoothedValue(window_size=64)
          running_val_loss = training_utils.SmoothedValue()
          train_accuracy = training_utils.SmoothedValue(window_size=64)
@@ -531,22 +585,37 @@ def train(args: argparse.Namespace) -> None:

              # Forward, backward and optimize
              with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
-                 with torch.inference_mode():
-                     teacher_outputs = teacher(inputs)
-
-                 softmax_teacher = F.softmax(teacher_outputs / args.temperature, dim=-1)
-                 if distillation_type == "soft":
-                     outputs = student(inputs)
-                     dist_output = F.log_softmax(outputs / args.temperature, dim=-1)
-                 elif distillation_type == "hard":
-                     outputs = student(inputs)
-                     dist_output = outputs
-                 elif distillation_type == "deit":
-                     (outputs, dist_output) = torch.unbind(student(inputs), dim=1)
+                 if distillation_type == "embedding":
+                     with torch.no_grad():
+                         teacher_embedding = teacher.embedding(inputs)
+                         teacher_embedding = F.normalize(teacher_embedding, dim=-1)
+
+                     (outputs, student_embedding) = train_student(inputs)
+                     student_embedding = embedding_projection(student_embedding)  # type: ignore[misc]
+                     student_embedding = F.normalize(student_embedding, dim=-1)
+                     dist_loss = distillation_criterion(student_embedding, teacher_embedding)
+
                  else:
-                     raise RuntimeError
+                     with torch.no_grad():
+                         teacher_outputs = teacher(inputs)
+                         if distillation_type == "soft":
+                             teacher_targets = F.softmax(teacher_outputs / args.temperature, dim=-1)
+                         else:
+                             teacher_targets = teacher_outputs.argmax(dim=-1)
+
+                     if distillation_type == "soft":
+                         outputs = train_student(inputs)
+                         dist_output = F.log_softmax(outputs / args.temperature, dim=-1)
+                         dist_loss = distillation_criterion(dist_output, teacher_targets) * (args.temperature**2)
+                     elif distillation_type == "hard":
+                         outputs = train_student(inputs)
+                         dist_loss = distillation_criterion(outputs, teacher_targets)
+                     elif distillation_type == "deit":
+                         (outputs, dist_output) = torch.unbind(train_student(inputs), dim=1)
+                         dist_loss = distillation_criterion(dist_output, teacher_targets)
+                     else:
+                         raise RuntimeError

-                 dist_loss = distillation_criterion(dist_output, softmax_teacher) * (args.temperature**2)
                  target_loss = criterion(outputs, targets)
                  loss = (1 - args.lambda_param) * target_loss + (args.lambda_param * dist_loss)

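For every distillation type the final objective is the same convex mix, loss = (1 - args.lambda_param) * CE(outputs, targets) + args.lambda_param * dist_loss; what changes is dist_loss: KL divergence on temperature-softened logits scaled by temperature**2 for "soft", a hard-label loss against the teacher's argmax prediction for "hard" and "deit", and MSE between L2-normalized embeddings for the new "embedding" type. A stand-alone sketch of the embedding term and the final mix is shown below; shapes and the lambda value are illustrative only.

import torch
import torch.nn.functional as F

# Illustrative tensors: a batch of 8 teacher embeddings (dim 768) and student
# embeddings assumed to be already mapped to the teacher width by the projection layer.
teacher_embedding = F.normalize(torch.randn(8, 768), dim=-1)
student_embedding = F.normalize(torch.randn(8, 768, requires_grad=True), dim=-1)

# Embedding matching (L2-normalized MSE), as in the "embedding" branch above
dist_loss = F.mse_loss(student_embedding, teacher_embedding)

# Final objective: convex combination with the regular cross-entropy term
outputs = torch.randn(8, 10, requires_grad=True)
targets = torch.randint(0, 10, (8,))
lambda_param = 0.5  # plays the role of args.lambda_param
loss = (1 - lambda_param) * F.cross_entropy(outputs, targets) + lambda_param * dist_loss
loss.backward()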
@@ -555,7 +624,11 @@ def train(args: argparse.Namespace) -> None:
                  if optimizer_update is True:
                      if args.clip_grad_norm is not None:
                          scaler.unscale_(optimizer)
-                         torch.nn.utils.clip_grad_norm_(student.parameters(), args.clip_grad_norm)
+                         params = list(train_student.parameters())
+                         if embedding_projection is not None:
+                             params += list(embedding_projection.parameters())
+
+                         torch.nn.utils.clip_grad_norm_(params, args.clip_grad_norm)

                      scaler.step(optimizer)
                      scaler.update()
@@ -567,7 +640,11 @@ def train(args: argparse.Namespace) -> None:
                  loss.backward()
                  if optimizer_update is True:
                      if args.clip_grad_norm is not None:
-                         torch.nn.utils.clip_grad_norm_(student.parameters(), args.clip_grad_norm)
+                         params = list(train_student.parameters())
+                         if embedding_projection is not None:
+                             params += list(embedding_projection.parameters())
+
+                         torch.nn.utils.clip_grad_norm_(params, args.clip_grad_norm)

                      optimizer.step()
                      optimizer.zero_grad()
@@ -710,6 +787,10 @@ def train(args: argparse.Namespace) -> None:
          if training_utils.is_local_primary(args) is True:
              # Checkpoint model
              if epoch % args.save_frequency == 0:
+                 extra_states = {}
+                 if embedding_projection_to_save is not None:
+                     extra_states["embedding_projection"] = embedding_projection_to_save.state_dict()
+
                  fs_ops.checkpoint_model(
                      student_name,
                      epoch,
@@ -721,6 +802,7 @@ def train(args: argparse.Namespace) -> None:
                      scheduler,
                      scaler,
                      model_base,
+                     **extra_states,
                  )
              if args.keep_last is not None:
                  fs_ops.clean_checkpoints(student_name, args.keep_last)
@@ -766,6 +848,10 @@ def train(args: argparse.Namespace) -> None:

      # Checkpoint model
      if training_utils.is_local_primary(args) is True:
+         extra_states = {}
+         if embedding_projection_to_save is not None:
+             extra_states["embedding_projection"] = embedding_projection_to_save.state_dict()
+
          fs_ops.checkpoint_model(
              student_name,
              epoch,
@@ -777,6 +863,7 @@ def train(args: argparse.Namespace) -> None:
              scheduler,
              scaler,
              model_base,
+             **extra_states,
          )

      training_utils.shutdown_distributed_mode(args)
@@ -896,6 +983,8 @@ def validate_args(args: argparse.Namespace) -> None:
      training_cli.common_args_validation(args)

      # Script specific checks
+     if args.type is None:
+         raise cli.ValidationError("--type is required")
      if args.teacher is None:
          raise cli.ValidationError("--teacher is required")
      if args.student is None:
@@ -905,6 +994,9 @@ def validate_args(args: argparse.Namespace) -> None:
      if registry.exists(args.student, task=Task.IMAGE_CLASSIFICATION) is False:
          raise cli.ValidationError(f"--student {args.student} not supported, see list-models tool for available options")

+     if args.type == "embedding" and (args.pts is True or args.pt2 is True):
+         raise cli.ValidationError("--type embedding does not support --pts or --pt2 teachers")
+
      if args.smoothing_alpha < 0 or args.smoothing_alpha >= 0.5:
          raise cli.ValidationError(f"--smoothing-alpha must be in range of [0, 0.5), got {args.smoothing_alpha}")

@@ -25,7 +25,6 @@ from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_mim_network_name
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
  from birder.data.datasets.directory import make_image_dataset
@@ -49,9 +48,7 @@ def train(args: argparse.Namespace) -> None:
      #
      # Initialize
      #
-     training_utils.init_distributed_mode(args)
-     logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-     training_utils.log_git_info()
+     (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

      if args.size is None:
          # Prefer mim size over encoder default size
@@ -59,32 +56,6 @@ def train(args: argparse.Namespace) -> None:

      logger.info(f"Using size={args.size}")

-     if args.cpu is True:
-         device = torch.device("cpu")
-         device_id = 0
-     else:
-         device = torch.device("cuda")
-         device_id = torch.cuda.current_device()
-
-     if args.use_deterministic_algorithms is True:
-         torch.backends.cudnn.benchmark = False
-         torch.use_deterministic_algorithms(True)
-     else:
-         torch.backends.cudnn.benchmark = True
-
-     if args.seed is not None:
-         set_random_seeds(args.seed)
-
-     if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-         disable_tqdm = True
-     elif sys.stderr.isatty() is False:
-         disable_tqdm = True
-     else:
-         disable_tqdm = False
-
-     # Enable or disable the autograd anomaly detection
-     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-

      #
      # Data
@@ -131,7 +102,7 @@ def train(args: argparse.Namespace) -> None:

      batch_size: int = args.batch_size
      grad_accum_steps: int = args.grad_accum_steps
-     logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+     logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

      # Data loaders and samplers
      if args.distributed is True:
@@ -172,6 +143,8 @@ def train(args: argparse.Namespace) -> None:
      else:
          args.stop_epoch += 1

+     logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
      #
      # Initialize network
      #
@@ -241,22 +214,25 @@ def train(args: argparse.Namespace) -> None:
      # Loss criteria, optimizer, learning rate scheduler and training parameter groups
      #

+     # Learning rate scaling
+     lr = training_utils.scale_lr(args)
+
      # Training parameter groups
      custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
      parameters = training_utils.optimizer_parameter_groups(
          net,
          args.wd,
+         base_lr=lr,
          norm_weight_decay=args.norm_wd,
          custom_keys_weight_decay=custom_keys_weight_decay,
+         custom_layer_weight_decay=args.custom_layer_wd,
          layer_decay=args.layer_decay,
          layer_decay_min_scale=args.layer_decay_min_scale,
          layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
          bias_lr=args.bias_lr,
+         custom_layer_lr_scale=args.custom_layer_lr_scale,
      )

-     # Learning rate scaling
-     lr = training_utils.scale_lr(args)
-
      if args.lr_scheduler_update == "epoch":
          step_update = False
          scheduler_steps_per_epoch = 1
@@ -36,7 +36,6 @@ from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_mim_network_name
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
  from birder.data.datasets.directory import make_image_dataset
@@ -74,41 +73,13 @@ def train(args: argparse.Namespace) -> None:
      #
      # Initialize
      #
-     training_utils.init_distributed_mode(args)
-     logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-     training_utils.log_git_info()
+     (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

      if args.size is None:
          args.size = registry.get_default_size(args.network)

      logger.info(f"Using size={args.size}")

-     if args.cpu is True:
-         device = torch.device("cpu")
-         device_id = 0
-     else:
-         device = torch.device("cuda")
-         device_id = torch.cuda.current_device()
-
-     if args.use_deterministic_algorithms is True:
-         torch.backends.cudnn.benchmark = False
-         torch.use_deterministic_algorithms(True)
-     else:
-         torch.backends.cudnn.benchmark = True
-
-     if args.seed is not None:
-         set_random_seeds(args.seed)
-
-     if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-         disable_tqdm = True
-     elif sys.stderr.isatty() is False:
-         disable_tqdm = True
-     else:
-         disable_tqdm = False
-
-     # Enable or disable the autograd anomaly detection
-     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
      #
      # Data
      #
@@ -155,7 +126,7 @@ def train(args: argparse.Namespace) -> None:

      batch_size: int = args.batch_size
      grad_accum_steps: int = args.grad_accum_steps
-     logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+     logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

      # Data loaders and samplers
      if args.distributed is True:
@@ -196,6 +167,8 @@ def train(args: argparse.Namespace) -> None:
      else:
          args.stop_epoch += 1

+     logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
      #
      # Initialize network
      #
@@ -243,22 +216,25 @@ def train(args: argparse.Namespace) -> None:
      # Loss
      mmcr_loss = MMCRMomentumLoss(args.lambda_coeff, n_aug=args.n_aug)

+     # Learning rate scaling
+     lr = training_utils.scale_lr(args)
+
      # Training parameter groups
      custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
      parameters = training_utils.optimizer_parameter_groups(
          net,
          args.wd,
+         base_lr=lr,
          norm_weight_decay=args.norm_wd,
          custom_keys_weight_decay=custom_keys_weight_decay,
+         custom_layer_weight_decay=args.custom_layer_wd,
          layer_decay=args.layer_decay,
          layer_decay_min_scale=args.layer_decay_min_scale,
          layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
          bias_lr=args.bias_lr,
+         custom_layer_lr_scale=args.custom_layer_lr_scale,
      )

-     # Learning rate scaling
-     lr = training_utils.scale_lr(args)
-
      if args.lr_scheduler_update == "epoch":
          step_update = False
          scheduler_steps_per_epoch = 1
@@ -31,7 +31,6 @@ from birder.common import training_cli
  from birder.common import training_utils
  from birder.common.lib import format_duration
  from birder.common.lib import get_network_name
- from birder.common.lib import set_random_seeds
  from birder.conf import settings
  from birder.data.dataloader.webdataset import make_wds_loader
  from birder.data.datasets.directory import make_image_dataset
@@ -83,41 +82,13 @@ def train(args: argparse.Namespace) -> None:
      #
      # Initialize
      #
-     training_utils.init_distributed_mode(args)
-     logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-     training_utils.log_git_info()
+     (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

      if args.size is None:
          args.size = registry.get_default_size(args.network)

      logger.info(f"Using size={args.size}")

-     if args.cpu is True:
-         device = torch.device("cpu")
-         device_id = 0
-     else:
-         device = torch.device("cuda")
-         device_id = torch.cuda.current_device()
-
-     if args.use_deterministic_algorithms is True:
-         torch.backends.cudnn.benchmark = False
-         torch.use_deterministic_algorithms(True)
-     else:
-         torch.backends.cudnn.benchmark = True
-
-     if args.seed is not None:
-         set_random_seeds(args.seed)
-
-     if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-         disable_tqdm = True
-     elif sys.stderr.isatty() is False:
-         disable_tqdm = True
-     else:
-         disable_tqdm = False
-
-     # Enable or disable the autograd anomaly detection
-     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
      #
      # Data
      #
@@ -169,7 +140,7 @@ def train(args: argparse.Namespace) -> None:

      batch_size: int = args.batch_size
      grad_accum_steps: int = args.grad_accum_steps
-     logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+     logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

      # Data loaders and samplers
      if args.distributed is True:
@@ -210,6 +181,8 @@ def train(args: argparse.Namespace) -> None:
      else:
          args.stop_epoch += 1

+     logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
      #
      # Initialize network
      #
@@ -252,25 +225,28 @@ def train(args: argparse.Namespace) -> None:
      # Loss criteria, optimizer, learning rate scheduler and training parameter groups
      #

+     # Learning rate scaling
+     lr = training_utils.scale_lr(args)
+
      # Training parameter groups
      custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
      parameters = training_utils.optimizer_parameter_groups(
          net,
          args.wd,
+         base_lr=lr,
          norm_weight_decay=args.norm_wd,
          custom_keys_weight_decay=custom_keys_weight_decay,
+         custom_layer_weight_decay=args.custom_layer_wd,
          layer_decay=args.layer_decay,
          layer_decay_min_scale=args.layer_decay_min_scale,
          layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
          bias_lr=args.bias_lr,
+         custom_layer_lr_scale=args.custom_layer_lr_scale,
      )
      # Loss criteria
      criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)


-     # Learning rate scaling
-     lr = training_utils.scale_lr(args)
-
      if args.lr_scheduler_update == "epoch":
          step_update = False
          scheduler_steps_per_epoch = 1