birder-0.2.1-py3-none-any.whl → birder-0.2.2-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. birder/adversarial/__init__.py +13 -0
  2. birder/adversarial/base.py +101 -0
  3. birder/adversarial/deepfool.py +173 -0
  4. birder/adversarial/fgsm.py +51 -18
  5. birder/adversarial/pgd.py +79 -28
  6. birder/adversarial/simba.py +172 -0
  7. birder/common/training_cli.py +11 -3
  8. birder/common/training_utils.py +18 -1
  9. birder/inference/data_parallel.py +1 -2
  10. birder/introspection/__init__.py +10 -6
  11. birder/introspection/attention_rollout.py +122 -54
  12. birder/introspection/base.py +73 -29
  13. birder/introspection/gradcam.py +71 -100
  14. birder/introspection/guided_backprop.py +146 -72
  15. birder/introspection/transformer_attribution.py +182 -0
  16. birder/net/detection/deformable_detr.py +14 -12
  17. birder/net/detection/detr.py +7 -3
  18. birder/net/detection/rt_detr_v1.py +3 -3
  19. birder/net/detection/yolo_v3.py +6 -11
  20. birder/net/detection/yolo_v4.py +7 -18
  21. birder/net/detection/yolo_v4_tiny.py +3 -3
  22. birder/net/fastvit.py +1 -1
  23. birder/net/mim/mae_vit.py +7 -8
  24. birder/net/pit.py +1 -1
  25. birder/net/resnet_v1.py +94 -34
  26. birder/net/ssl/data2vec.py +1 -1
  27. birder/net/ssl/data2vec2.py +4 -2
  28. birder/results/gui.py +15 -2
  29. birder/scripts/predict_detection.py +33 -1
  30. birder/scripts/train.py +24 -17
  31. birder/scripts/train_barlow_twins.py +10 -7
  32. birder/scripts/train_byol.py +10 -7
  33. birder/scripts/train_capi.py +12 -9
  34. birder/scripts/train_data2vec.py +10 -7
  35. birder/scripts/train_data2vec2.py +10 -7
  36. birder/scripts/train_detection.py +42 -18
  37. birder/scripts/train_dino_v1.py +10 -7
  38. birder/scripts/train_dino_v2.py +10 -7
  39. birder/scripts/train_dino_v2_dist.py +17 -7
  40. birder/scripts/train_franca.py +10 -7
  41. birder/scripts/train_i_jepa.py +17 -13
  42. birder/scripts/train_ibot.py +10 -7
  43. birder/scripts/train_kd.py +24 -18
  44. birder/scripts/train_mim.py +11 -10
  45. birder/scripts/train_mmcr.py +10 -7
  46. birder/scripts/train_rotnet.py +10 -7
  47. birder/scripts/train_simclr.py +10 -7
  48. birder/scripts/train_vicreg.py +10 -7
  49. birder/tools/__main__.py +6 -2
  50. birder/tools/adversarial.py +147 -96
  51. birder/tools/auto_anchors.py +361 -0
  52. birder/tools/ensemble_model.py +1 -1
  53. birder/tools/introspection.py +58 -31
  54. birder/version.py +1 -1
  55. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
  56. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
  57. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
  58. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
  59. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
  60. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
@@ -214,7 +214,8 @@ def train(args: argparse.Namespace) -> None:
  torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
@@ -417,6 +418,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1

  #
@@ -438,20 +440,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -492,11 +493,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)
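The hunks above, and the near-identical ones repeated for the other training scripts below, share one refactor: grad_accum_steps is bound once near the top of train(), optimizer_steps_per_epoch is derived from it right after the data loader is built, and the scheduler is driven by scheduler_steps_per_epoch instead of the old steps_per_epoch. The following minimal sketch shows how those quantities interact under gradient accumulation; the loader size, model, optimizer, scheduler, and the optimizer_update condition are stand-ins, not the birder pipeline:

    import math

    import torch
    from torch.optim.lr_scheduler import StepLR

    # Stand-in values; in the scripts these come from args and the DataLoader.
    batches_per_epoch = 10          # len(training_loader)
    grad_accum_steps = 4            # args.grad_accum_steps
    batch_size, world_size = 32, 1  # args.batch_size, args.world_size

    # Effective batch size, as logged by the scripts.
    effective_batch_size = batch_size * grad_accum_steps * world_size

    # One optimizer step per grad_accum_steps micro-batches, so in "step" mode the
    # scheduler advances this many times per epoch.
    optimizer_steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = StepLR(optimizer, step_size=optimizer_steps_per_epoch, gamma=0.5)

    for batch_idx in range(batches_per_epoch):
        # ... forward/backward on the micro-batch would go here ...
        optimizer_update = (batch_idx + 1) % grad_accum_steps == 0 or batch_idx == batches_per_epoch - 1
        if optimizer_update:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()  # lr_scheduler_update == "step": one scheduler step per optimizer step

    print(effective_batch_size, optimizer_steps_per_epoch, scheduler.get_last_lr())

Binding grad_accum_steps once near the top also removes the duplicated args.grad_accum_steps lookups that previously appeared in both the logging line and the scheduler setup.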
@@ -215,7 +215,8 @@ def train(args: argparse.Namespace) -> None:
  torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
@@ -240,6 +241,7 @@ def train(args: argparse.Namespace) -> None:
  args.network, sample_shape[1], 0, config=args.model_config, size=args.size
  )
  student_backbone_ema.load_state_dict(student_backbone.state_dict())
+ student_backbone_ema.requires_grad_(False)

  teacher_backbone = registry.net_factory(
  args.teacher,
@@ -248,6 +250,11 @@ def train(args: argparse.Namespace) -> None:
  config=args.teacher_model_config,
  size=args.size,
  )
+ assert student_backbone.max_stride == teacher_backbone.max_stride, (
+ "Student and teacher max_stride must match for distillation "
+ f"(student={student_backbone.max_stride}, teacher={teacher_backbone.max_stride})"
+ )
+
  student_backbone.set_dynamic_size()
  if args.ibot_separate_head is False:
  args.ibot_out_dim = args.dino_out_dim
@@ -433,6 +440,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1

  #
@@ -454,20 +462,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -507,11 +514,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)

@@ -604,6 +613,7 @@ def train(args: argparse.Namespace) -> None:
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
+ teacher.eval()
  running_loss = training_utils.SmoothedValue()
  running_loss_dino_local = training_utils.SmoothedValue()
  running_loss_dino_global = training_utils.SmoothedValue()
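These hunks harden the student/teacher setup: the EMA copy of the student backbone is taken out of autograd with requires_grad_(False), the student and teacher strides are checked before training starts, and the teacher is explicitly put back into eval mode at the top of every epoch. A rough sketch of the frozen-copy pattern follows; the modules below are toy placeholders, not the script's actual backbone classes:

    import copy

    import torch

    student_backbone = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))

    # The EMA copy starts as an exact clone of the student...
    student_backbone_ema = copy.deepcopy(student_backbone)
    student_backbone_ema.load_state_dict(student_backbone.state_dict())

    # ...and never receives gradients: its weights change only through explicit EMA updates.
    student_backbone_ema.requires_grad_(False)

    # Frozen copies are also kept in eval mode so that BatchNorm running statistics and
    # dropout are unaffected by teacher/EMA forward passes during training.
    student_backbone_ema.eval()

    with torch.no_grad():
        out = student_backbone_ema(torch.randn(2, 3, 16, 16))
    print(out.shape, any(p.requires_grad for p in student_backbone_ema.parameters()))

The new max_stride assertion plays a similar fail-fast role: mismatched feature-map strides between student and teacher would otherwise only surface later as shape errors inside the distillation losses.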
@@ -241,7 +241,8 @@ def train(args: argparse.Namespace) -> None:
  torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
@@ -444,6 +445,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1

  #
@@ -465,20 +467,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -519,11 +520,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)
@@ -120,7 +120,8 @@ def train(args: argparse.Namespace) -> None:
  torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
@@ -284,6 +285,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1

  #
@@ -305,20 +307,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -351,11 +352,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)

@@ -521,12 +524,13 @@ def train(args: argparse.Namespace) -> None:
  if step_update is True:
  scheduler.step()

- # EMA update for the target encoder
- with torch.no_grad():
- m = momentum_schedule[global_iter]
- torch._foreach_lerp_( # pylint: disable=protected-access
- list(target_encoder.parameters()), list(encoder.parameters()), weight=1 - m
- )
+ if optimizer_update is True:
+ # EMA update for the target encoder
+ with torch.no_grad():
+ m = momentum_schedule[global_iter]
+ torch._foreach_lerp_( # pylint: disable=protected-access
+ list(target_encoder.parameters()), list(encoder.parameters()), weight=1 - m
+ )

  # Statistics
  running_loss.update(loss.detach())
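Wrapping the target-encoder update in "if optimizer_update is True:" ties the EMA momentum step to optimizer steps rather than micro-batches, so gradient accumulation no longer advances the momentum update faster than the weights actually change. Below is a toy sketch of that gated update using the same fused torch._foreach_lerp_ call as the diff; the schedule, modules, and the optimizer_update condition are illustrative stand-ins:

    import torch

    encoder = torch.nn.Linear(8, 8)
    target_encoder = torch.nn.Linear(8, 8)
    target_encoder.load_state_dict(encoder.state_dict())
    target_encoder.requires_grad_(False)

    grad_accum_steps = 2
    # Toy momentum schedule with one value per optimizer step.
    momentum_schedule = [0.996 + 0.0001 * i for i in range(100)]

    global_iter = 0
    for batch_idx in range(8):
        # ... forward/backward, and optimizer.step() on accumulation boundaries ...
        optimizer_update = (batch_idx + 1) % grad_accum_steps == 0
        if optimizer_update:
            # EMA update for the target encoder: target <- m * target + (1 - m) * encoder
            with torch.no_grad():
                m = momentum_schedule[global_iter]
                torch._foreach_lerp_(  # pylint: disable=protected-access
                    list(target_encoder.parameters()), list(encoder.parameters()), weight=1 - m
                )
            global_iter += 1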
@@ -143,7 +143,8 @@ def train(args: argparse.Namespace) -> None:
  torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  begin_epoch = 1
  epochs = args.epochs + 1
@@ -351,6 +352,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1

  #
@@ -372,20 +374,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -418,11 +419,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)
@@ -186,8 +186,9 @@ def train(args: argparse.Namespace) -> None:

  num_outputs = len(class_to_idx)
  batch_size: int = args.batch_size
- model_ema_steps: int = args.model_ema_steps * args.grad_accum_steps
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ model_ema_steps: int = args.model_ema_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  # Set data iterators
  if args.mixup_alpha is not None or args.cutmix is True:
@@ -246,8 +247,8 @@ def train(args: argparse.Namespace) -> None:
  pin_memory=True,
  )

- optimizer_steps_per_epoch = math.ceil(len(training_loader) / args.grad_accum_steps)
- assert args.model_ema is False or args.model_ema_steps <= optimizer_steps_per_epoch
+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ assert args.model_ema is False or model_ema_steps <= optimizer_steps_per_epoch

  last_batch_idx = len(training_loader) - 1
  begin_epoch = 1
@@ -336,20 +337,18 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps
-
  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -375,11 +374,14 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False),
+ lrs,
+ )
  plt.show()
  raise SystemExit(0)

@@ -387,15 +389,15 @@ def train(args: argparse.Namespace) -> None:
  # Distributed (DDP) and Model EMA
  #
  if args.model_ema_warmup is not None:
- ema_warmup_epochs = args.model_ema_warmup
+ ema_warmup_steps = args.model_ema_warmup * optimizer_steps_per_epoch
  elif args.warmup_epochs is not None:
- ema_warmup_epochs = args.warmup_epochs
+ ema_warmup_steps = args.warmup_epochs * optimizer_steps_per_epoch
  elif args.warmup_steps is not None:
- ema_warmup_epochs = args.warmup_steps // steps_per_epoch
+ ema_warmup_steps = args.warmup_steps
  else:
- ema_warmup_epochs = 0
+ ema_warmup_steps = 0

- logger.debug(f"EMA warmup epochs = {ema_warmup_epochs}")
+ logger.debug(f"EMA warmup steps = {ema_warmup_steps}")
  net_without_ddp = student
  if args.distributed is True:
  student = torch.nn.parallel.DistributedDataParallel(
@@ -493,6 +495,7 @@ def train(args: argparse.Namespace) -> None:
  #
  # Training loop
  #
+ optimizer_step = (begin_epoch - 1) * optimizer_steps_per_epoch
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
@@ -571,10 +574,13 @@ def train(args: argparse.Namespace) -> None:
  if step_update is True:
  scheduler.step()

+ if optimizer_update is True:
+ optimizer_step += 1
+
  # Exponential moving average
- if args.model_ema is True and i % model_ema_steps == 0:
+ if args.model_ema is True and optimizer_update is True and optimizer_step % model_ema_steps == 0:
  model_ema.update_parameters(student)
- if epoch <= ema_warmup_epochs:
+ if ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps:
  # Reset ema buffer to keep copying weights during warmup period
  model_ema.n_averaged.fill_(0) # pylint: disable=no-member
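Together these hunks move the model-EMA bookkeeping onto an explicit optimizer-step counter: model_ema_steps is no longer pre-multiplied by the accumulation factor, the warmup horizon is expressed in optimizer steps, and the EMA update only fires on iterations where the optimizer actually stepped. A condensed sketch of that control flow follows; the counters, thresholds, and the optimizer_update condition are illustrative, with AveragedModel standing in for the script's EMA wrapper:

    import math

    import torch
    from torch.optim.swa_utils import AveragedModel

    student = torch.nn.Linear(4, 4)
    model_ema = AveragedModel(student)

    batches_per_epoch, grad_accum_steps = 8, 2
    optimizer_steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)

    model_ema_steps = 2                               # update the EMA every N optimizer steps
    ema_warmup_steps = 1 * optimizer_steps_per_epoch  # warmup expressed in optimizer steps

    begin_epoch, stop_epoch = 1, 3
    optimizer_step = (begin_epoch - 1) * optimizer_steps_per_epoch
    for epoch in range(begin_epoch, stop_epoch):
        for i in range(batches_per_epoch):
            # ... forward/backward, optimizer.step() on accumulation boundaries ...
            optimizer_update = (i + 1) % grad_accum_steps == 0
            if optimizer_update:
                optimizer_step += 1

            # Exponential moving average, counted in optimizer steps rather than batch indices
            if optimizer_update and optimizer_step % model_ema_steps == 0:
                model_ema.update_parameters(student)
                if ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps:
                    # Reset the averaging counter so the EMA keeps copying weights during warmup
                    model_ema.n_averaged.fill_(0)

Counting optimizer steps, and seeding optimizer_step from begin_epoch on resume, keeps the EMA cadence and warmup consistent whether or not gradient accumulation is enabled.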
@@ -130,7 +130,8 @@ def train(args: argparse.Namespace) -> None:
  logger.info(f"Training on {len(training_dataset):,} samples")

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  # Data loaders and samplers
  if args.distributed is True:
@@ -162,6 +163,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
  begin_epoch = 1
  epochs = args.epochs + 1
@@ -254,20 +256,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -293,11 +294,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)

@@ -599,9 +602,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
  ),
  )
- parser.add_argument(
- "--mask-ratio", type=float, default=None, help="mask ratio for MIM training (default: model-specific)"
- )
+ parser.add_argument("--mask-ratio", type=float, help="mask ratio for MIM training (default: model-specific)")
  parser.add_argument("--min-mask-size", type=int, default=1, help="minimum mask unit size in patches")
  training_cli.add_optimization_args(parser)
  training_cli.add_lr_wd_args(parser)
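The --mask-ratio change is behavior-preserving: argparse already defaults an option to None when no default= is supplied, so the explicit default=None was redundant. A quick check of that assumption:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--mask-ratio", type=float, help="mask ratio for MIM training (default: model-specific)")

    # With no default= given, argparse falls back to None, matching the old explicit default=None.
    print(parser.parse_args([]).mask_ratio)                        # None
    print(parser.parse_args(["--mask-ratio", "0.6"]).mask_ratio)   # 0.6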
@@ -154,7 +154,8 @@ def train(args: argparse.Namespace) -> None:
  logger.info(f"Training on {len(training_dataset):,} samples")

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  # Data loaders and samplers
  if args.distributed is True:
@@ -186,6 +187,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
  begin_epoch = 1
  epochs = args.epochs + 1
@@ -256,20 +258,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -295,11 +296,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)
@@ -168,7 +168,8 @@ def train(args: argparse.Namespace) -> None:
  logger.info(f"Training on {len(training_dataset):,} samples")

  batch_size: int = args.batch_size
- logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+ grad_accum_steps: int = args.grad_accum_steps
+ logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

  # Data loaders and samplers
  if args.distributed is True:
@@ -200,6 +201,7 @@ def train(args: argparse.Namespace) -> None:
  drop_last=args.drop_last,
  )

+ optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
  last_batch_idx = len(training_loader) - 1
  begin_epoch = 1
  epochs = args.epochs + 1
@@ -268,20 +270,19 @@ def train(args: argparse.Namespace) -> None:

  # Learning rate scaling
  lr = training_utils.scale_lr(args)
- grad_accum_steps: int = args.grad_accum_steps

  if args.lr_scheduler_update == "epoch":
  step_update = False
- steps_per_epoch = 1
+ scheduler_steps_per_epoch = 1
  elif args.lr_scheduler_update == "step":
  step_update = True
- steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+ scheduler_steps_per_epoch = optimizer_steps_per_epoch
  else:
  raise ValueError("Unsupported lr_scheduler_update")

  # Optimizer and learning rate scheduler
  optimizer = training_utils.get_optimizer(parameters, lr, args)
- scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+ scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
  if args.compile_opt is True:
  optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -307,11 +308,13 @@ def train(args: argparse.Namespace) -> None:
  optimizer.step()
  lrs = []
  for _ in range(begin_epoch, epochs):
- for _ in range(steps_per_epoch):
+ for _ in range(scheduler_steps_per_epoch):
  lrs.append(float(max(scheduler.get_last_lr())))
  scheduler.step()

- plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+ plt.plot(
+ np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+ )
  plt.show()
  raise SystemExit(0)