birder-0.2.1-py3-none-any.whl → birder-0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. birder/adversarial/__init__.py +13 -0
  2. birder/adversarial/base.py +101 -0
  3. birder/adversarial/deepfool.py +173 -0
  4. birder/adversarial/fgsm.py +51 -18
  5. birder/adversarial/pgd.py +79 -28
  6. birder/adversarial/simba.py +172 -0
  7. birder/common/training_cli.py +11 -3
  8. birder/common/training_utils.py +18 -1
  9. birder/inference/data_parallel.py +1 -2
  10. birder/introspection/__init__.py +10 -6
  11. birder/introspection/attention_rollout.py +122 -54
  12. birder/introspection/base.py +73 -29
  13. birder/introspection/gradcam.py +71 -100
  14. birder/introspection/guided_backprop.py +146 -72
  15. birder/introspection/transformer_attribution.py +182 -0
  16. birder/net/detection/deformable_detr.py +14 -12
  17. birder/net/detection/detr.py +7 -3
  18. birder/net/detection/rt_detr_v1.py +3 -3
  19. birder/net/detection/yolo_v3.py +6 -11
  20. birder/net/detection/yolo_v4.py +7 -18
  21. birder/net/detection/yolo_v4_tiny.py +3 -3
  22. birder/net/fastvit.py +1 -1
  23. birder/net/mim/mae_vit.py +7 -8
  24. birder/net/pit.py +1 -1
  25. birder/net/resnet_v1.py +94 -34
  26. birder/net/ssl/data2vec.py +1 -1
  27. birder/net/ssl/data2vec2.py +4 -2
  28. birder/results/gui.py +15 -2
  29. birder/scripts/predict_detection.py +33 -1
  30. birder/scripts/train.py +24 -17
  31. birder/scripts/train_barlow_twins.py +10 -7
  32. birder/scripts/train_byol.py +10 -7
  33. birder/scripts/train_capi.py +12 -9
  34. birder/scripts/train_data2vec.py +10 -7
  35. birder/scripts/train_data2vec2.py +10 -7
  36. birder/scripts/train_detection.py +42 -18
  37. birder/scripts/train_dino_v1.py +10 -7
  38. birder/scripts/train_dino_v2.py +10 -7
  39. birder/scripts/train_dino_v2_dist.py +17 -7
  40. birder/scripts/train_franca.py +10 -7
  41. birder/scripts/train_i_jepa.py +17 -13
  42. birder/scripts/train_ibot.py +10 -7
  43. birder/scripts/train_kd.py +24 -18
  44. birder/scripts/train_mim.py +11 -10
  45. birder/scripts/train_mmcr.py +10 -7
  46. birder/scripts/train_rotnet.py +10 -7
  47. birder/scripts/train_simclr.py +10 -7
  48. birder/scripts/train_vicreg.py +10 -7
  49. birder/tools/__main__.py +6 -2
  50. birder/tools/adversarial.py +147 -96
  51. birder/tools/auto_anchors.py +361 -0
  52. birder/tools/ensemble_model.py +1 -1
  53. birder/tools/introspection.py +58 -31
  54. birder/version.py +1 -1
  55. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
  56. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
  57. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
  58. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
  59. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
  60. {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/scripts/train.py CHANGED
@@ -160,8 +160,9 @@ def train(args: argparse.Namespace) -> None:

     num_outputs = len(class_to_idx)
     batch_size: int = args.batch_size
-    model_ema_steps: int = args.model_ema_steps * args.grad_accum_steps
-    logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+    grad_accum_steps: int = args.grad_accum_steps
+    model_ema_steps: int = args.model_ema_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Set data iterators
     if args.mixup_alpha is not None or args.cutmix is True:
@@ -220,8 +221,8 @@ def train(args: argparse.Namespace) -> None:
         pin_memory=True,
     )

-    optimizer_steps_per_epoch = math.ceil(len(training_loader) / args.grad_accum_steps)
-    assert args.model_ema is False or args.model_ema_steps <= optimizer_steps_per_epoch
+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+    assert args.model_ema is False or model_ema_steps <= optimizer_steps_per_epoch

     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1
@@ -317,20 +318,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-        steps_per_epoch = 1
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-        steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -356,11 +356,14 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(steps_per_epoch):
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
             scheduler.step()

-        plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False),
+            lrs,
+        )
         plt.show()
         raise SystemExit(0)

@@ -368,15 +371,15 @@ def train(args: argparse.Namespace) -> None:
     # Distributed (DDP) and Model EMA
     #
     if args.model_ema_warmup is not None:
-        ema_warmup_epochs = args.model_ema_warmup
+        ema_warmup_steps = args.model_ema_warmup * optimizer_steps_per_epoch
     elif args.warmup_epochs is not None:
-        ema_warmup_epochs = args.warmup_epochs
+        ema_warmup_steps = args.warmup_epochs * optimizer_steps_per_epoch
     elif args.warmup_steps is not None:
-        ema_warmup_epochs = args.warmup_steps // steps_per_epoch
+        ema_warmup_steps = args.warmup_steps
     else:
-        ema_warmup_epochs = 0
+        ema_warmup_steps = 0

-    logger.debug(f"EMA warmup epochs = {ema_warmup_epochs}")
+    logger.debug(f"EMA warmup steps = {ema_warmup_steps}")
     net_without_ddp = net
     if args.distributed is True:
         net = torch.nn.parallel.DistributedDataParallel(
@@ -474,6 +477,7 @@ def train(args: argparse.Namespace) -> None:
     #
     # Training loop
     #
+    optimizer_step = (begin_epoch - 1) * optimizer_steps_per_epoch
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
@@ -542,10 +546,13 @@ def train(args: argparse.Namespace) -> None:
             if step_update is True:
                 scheduler.step()

+            if optimizer_update is True:
+                optimizer_step += 1
+
             # Exponential moving average
-            if args.model_ema is True and i % model_ema_steps == 0:
+            if args.model_ema is True and optimizer_update is True and optimizer_step % model_ema_steps == 0:
                 model_ema.update_parameters(net)
-                if epoch <= ema_warmup_epochs:
+                if ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps:
                     # Reset ema buffer to keep copying weights during warmup period
                     model_ema.n_averaged.fill_(0)  # pylint: disable=no-member

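The hunks above count completed optimizer steps under gradient accumulation and key the EMA update to that counter rather than the raw batch index. A minimal, self-contained sketch of that pattern follows; all names (net, ema, loader, loss_fn, ...) are illustrative placeholders, not Birder's actual training-loop API, and `ema` is assumed to behave like `torch.optim.swa_utils.AveragedModel`.

# Sketch: EMA updates keyed to optimizer steps under gradient accumulation (hypothetical names).
import math


def train_with_step_based_ema(net, ema, loader, loss_fn, optimizer, epochs,
                              grad_accum_steps=1, model_ema_steps=32, warmup_epochs=0):
    optimizer_steps_per_epoch = math.ceil(len(loader) / grad_accum_steps)
    ema_warmup_steps = warmup_epochs * optimizer_steps_per_epoch
    optimizer_step = 0
    last_batch_idx = len(loader) - 1

    for _epoch in range(1, epochs + 1):
        for i, (x, y) in enumerate(loader):
            loss = loss_fn(net(x), y) / grad_accum_steps
            loss.backward()

            # An optimizer update happens every grad_accum_steps batches (and on the last batch)
            optimizer_update = (i + 1) % grad_accum_steps == 0 or i == last_batch_idx
            if optimizer_update:
                optimizer.step()
                optimizer.zero_grad()
                optimizer_step += 1

                # EMA cadence follows completed optimizer steps, not data-loader iterations
                if optimizer_step % model_ema_steps == 0:
                    ema.update_parameters(net)
                    if ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps:
                        # During warmup keep the EMA an exact copy of the current weights
                        ema.n_averaged.fill_(0)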
@@ -149,7 +149,8 @@ def train(args: argparse.Namespace) -> None:
     logger.info(f"Training on {len(training_dataset):,} samples")

     batch_size: int = args.batch_size
-    logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -181,6 +182,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1
     epochs = args.epochs + 1
@@ -244,20 +246,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-        steps_per_epoch = 1
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-        steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -283,11 +284,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(steps_per_epoch):
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
             scheduler.step()

-        plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)

@@ -151,7 +151,8 @@ def train(args: argparse.Namespace) -> None:
     logger.info(f"Training on {len(training_dataset):,} samples")

     batch_size: int = args.batch_size
-    logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -183,6 +184,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1
     epochs = args.epochs + 1
@@ -256,20 +258,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-        steps_per_epoch = 1
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-        steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -295,11 +296,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(steps_per_epoch):
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
             scheduler.step()

-        plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)

@@ -115,7 +115,8 @@ def train(args: argparse.Namespace) -> None:
     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

     batch_size: int = args.batch_size
-    logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     begin_epoch = 1
     epochs = args.epochs + 1
@@ -278,6 +279,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1

     #
@@ -300,22 +302,21 @@ def train(args: argparse.Namespace) -> None:
     # Learning rate scaling
     lr = training_utils.scale_lr(args)
     clustering_lr = lr / 2
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-        steps_per_epoch = 1
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-        steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
     clustering_optimizer = torch.optim.AdamW(teacher.head.parameters(), lr=clustering_lr, betas=[0.9, 0.95])
-    scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
-    clustering_scheduler = training_utils.get_scheduler(clustering_optimizer, steps_per_epoch, args)
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
+    clustering_scheduler = training_utils.get_scheduler(clustering_optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)
         clustering_optimizer.step = torch.compile(clustering_optimizer.step, fullgraph=False)
@@ -324,7 +325,7 @@ def train(args: argparse.Namespace) -> None:
     if args.warmup_epochs is not None:
         warmup_epochs = args.warmup_epochs
     elif args.warmup_steps is not None:
-        warmup_epochs = args.warmup_steps / steps_per_epoch
+        warmup_epochs = args.warmup_steps / scheduler_steps_per_epoch
     else:
         warmup_epochs = 0.0

@@ -353,11 +354,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(steps_per_epoch):
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
             scheduler.step()

-        plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)

@@ -106,7 +106,8 @@ def train(args: argparse.Namespace) -> None:
     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

     batch_size: int = args.batch_size
-    logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     begin_epoch = 1
     epochs = args.epochs + 1
@@ -245,6 +246,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1

     #
@@ -266,20 +268,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-        steps_per_epoch = 1
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-        steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -310,11 +311,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(steps_per_epoch):
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
             scheduler.step()

-        plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)

@@ -112,7 +112,8 @@ def train(args: argparse.Namespace) -> None:
     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

     batch_size: int = args.batch_size
-    logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     begin_epoch = 1
     epochs = args.epochs + 1
@@ -254,6 +255,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1

     #
@@ -275,20 +277,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-        steps_per_epoch = 1
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-        steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -319,11 +320,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(steps_per_epoch):
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
             scheduler.step()

-        plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)

@@ -192,8 +192,9 @@ def train(args: argparse.Namespace) -> None:

     num_outputs = len(class_to_idx)  # Does not include background class
     batch_size: int = args.batch_size
-    model_ema_steps: int = args.model_ema_steps * args.grad_accum_steps
-    logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+    grad_accum_steps: int = args.grad_accum_steps
+    model_ema_steps: int = args.model_ema_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     (train_sampler, validation_sampler) = training_utils.get_samplers(args, training_dataset, validation_dataset)
@@ -224,8 +225,8 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

-    optimizer_steps_per_epoch = math.ceil(len(training_loader) / args.grad_accum_steps)
-    assert args.model_ema is False or args.model_ema_steps <= optimizer_steps_per_epoch
+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+    assert args.model_ema is False or model_ema_steps <= optimizer_steps_per_epoch

     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1
@@ -369,20 +370,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-        steps_per_epoch = 1
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-        steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -408,11 +408,14 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(steps_per_epoch):
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
             scheduler.step()

-        plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False),
+            lrs,
+        )
         plt.show()
         raise SystemExit(0)

@@ -420,15 +423,15 @@ def train(args: argparse.Namespace) -> None:
     # Distributed (DDP) and Model EMA
     #
     if args.model_ema_warmup is not None:
-        ema_warmup_epochs = args.model_ema_warmup
+        ema_warmup_steps = args.model_ema_warmup * optimizer_steps_per_epoch
     elif args.warmup_epochs is not None:
-        ema_warmup_epochs = args.warmup_epochs
+        ema_warmup_steps = args.warmup_epochs * optimizer_steps_per_epoch
     elif args.warmup_steps is not None:
-        ema_warmup_epochs = args.warmup_steps // steps_per_epoch
+        ema_warmup_steps = args.warmup_steps
     else:
-        ema_warmup_epochs = 0
+        ema_warmup_steps = 0

-    logger.debug(f"EMA warmup epochs = {ema_warmup_epochs}")
+    logger.debug(f"EMA warmup steps = {ema_warmup_steps}")
     net_without_ddp = net
     if args.distributed is True:
         net = torch.nn.parallel.DistributedDataParallel(
@@ -532,11 +535,13 @@ def train(args: argparse.Namespace) -> None:
     #
     # Training loop
     #
+    optimizer_step = (begin_epoch - 1) * optimizer_steps_per_epoch
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
         running_loss = training_utils.SmoothedValue()
+        loss_trackers: dict[str, training_utils.SmoothedValue] = {}
         validation_metrics.reset()

         if args.distributed is True:
@@ -598,16 +603,28 @@ def train(args: argparse.Namespace) -> None:
             if step_update is True:
                 scheduler.step()

+            if optimizer_update is True:
+                optimizer_step += 1
+
             # Exponential moving average
-            if args.model_ema is True and i % model_ema_steps == 0:
+            if args.model_ema is True and optimizer_update is True and optimizer_step % model_ema_steps == 0:
                 model_ema.update_parameters(net)
-                if epoch <= ema_warmup_epochs:
+                if ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps:
                     # Reset ema buffer to keep copying weights during warmup period
                     model_ema.n_averaged.fill_(0)  # pylint: disable=no-member

             # Statistics
             running_loss.update(loss.detach())

+            # Dynamically create trackers on first batch
+            if len(loss_trackers) == 0:
+                for key in losses.keys():
+                    loss_trackers[key] = training_utils.SmoothedValue()
+
+            # Update individual loss trackers
+            for key, value in losses.items():
+                loss_trackers[key].update(value.detach())
+
             # Write statistics
             if (i == last_batch_idx) or (i + 1) % args.log_interval == 0:
                 time_now = time.time()
@@ -624,6 +641,9 @@ def train(args: argparse.Namespace) -> None:
                 cur_lr = float(max(scheduler.get_last_lr()))

                 running_loss.synchronize_between_processes(device)
+                for tracker in loss_trackers.values():
+                    tracker.synchronize_between_processes(device)
+
                 with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
                     log.info(
                         f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
@@ -636,9 +656,11 @@ def train(args: argparse.Namespace) -> None:
                     )

                 if training_utils.is_local_primary(args) is True:
+                    loss_dict = {"training": running_loss.avg}
+                    loss_dict.update({k: v.avg for k, v in loss_trackers.items()})
                     summary_writer.add_scalars(
                         "loss",
-                        {"training": running_loss.avg},
+                        loss_dict,
                         ((epoch - 1) * len(training_dataset)) + (i * batch_size * args.world_size),
                     )

@@ -649,6 +671,8 @@ def train(args: argparse.Namespace) -> None:

         # Epoch training metrics
         logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_loss: {running_loss.global_avg:.4f}")
+        for key, tracker in loss_trackers.items():
+            logger.info(f"[Trn] Epoch {epoch}/{epochs-1} {key}: {tracker.global_avg:.4f}")

         # Validation
         eval_model.eval()
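The hunks above (apparently from the detection training script) also keep a running average per loss component next to the total loss, creating a tracker the first time each key appears. A small sketch of that bookkeeping, using a plain RunningAvg class as a stand-in for Birder's training_utils.SmoothedValue; the loss keys below are hypothetical:

# Sketch: per-component running averages for a dict of losses (illustrative names).
import torch


class RunningAvg:
    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: torch.Tensor) -> None:
        self.total += float(value)
        self.count += 1

    @property
    def avg(self) -> float:
        return self.total / max(self.count, 1)


loss_trackers: dict[str, RunningAvg] = {}
for _batch in range(3):
    # Hypothetical per-component losses as a detection model might return them
    losses = {"classification": torch.rand(1), "bbox_regression": torch.rand(1)}

    # Create a tracker per key on the first batch, then update every batch
    if len(loss_trackers) == 0:
        for key in losses:
            loss_trackers[key] = RunningAvg()
    for key, value in losses.items():
        loss_trackers[key].update(value.detach())

for key, tracker in loss_trackers.items():
    print(f"{key}: {tracker.avg:.4f}")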
@@ -186,7 +186,8 @@ def train(args: argparse.Namespace) -> None:
     logger.info(f"Training on {len(training_dataset):,} samples")

     batch_size: int = args.batch_size
-    logger.debug(f"Effective batch size = {args.batch_size * args.grad_accum_steps * args.world_size}")
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -218,6 +219,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1
     epochs = args.epochs + 1
@@ -352,20 +354,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-        steps_per_epoch = 1
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-        steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer, steps_per_epoch, args)
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -398,11 +399,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(steps_per_epoch):
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
             scheduler.step()

-        plt.plot(np.linspace(begin_epoch, epochs, steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)
