birder-0.2.1-py3-none-any.whl → birder-0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/__init__.py +13 -0
- birder/adversarial/base.py +101 -0
- birder/adversarial/deepfool.py +173 -0
- birder/adversarial/fgsm.py +51 -18
- birder/adversarial/pgd.py +79 -28
- birder/adversarial/simba.py +172 -0
- birder/common/training_cli.py +11 -3
- birder/common/training_utils.py +18 -1
- birder/inference/data_parallel.py +1 -2
- birder/introspection/__init__.py +10 -6
- birder/introspection/attention_rollout.py +122 -54
- birder/introspection/base.py +73 -29
- birder/introspection/gradcam.py +71 -100
- birder/introspection/guided_backprop.py +146 -72
- birder/introspection/transformer_attribution.py +182 -0
- birder/net/detection/deformable_detr.py +14 -12
- birder/net/detection/detr.py +7 -3
- birder/net/detection/rt_detr_v1.py +3 -3
- birder/net/detection/yolo_v3.py +6 -11
- birder/net/detection/yolo_v4.py +7 -18
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/fastvit.py +1 -1
- birder/net/mim/mae_vit.py +7 -8
- birder/net/pit.py +1 -1
- birder/net/resnet_v1.py +94 -34
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +4 -2
- birder/results/gui.py +15 -2
- birder/scripts/predict_detection.py +33 -1
- birder/scripts/train.py +24 -17
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +12 -9
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +42 -18
- birder/scripts/train_dino_v1.py +10 -7
- birder/scripts/train_dino_v2.py +10 -7
- birder/scripts/train_dino_v2_dist.py +17 -7
- birder/scripts/train_franca.py +10 -7
- birder/scripts/train_i_jepa.py +17 -13
- birder/scripts/train_ibot.py +10 -7
- birder/scripts/train_kd.py +24 -18
- birder/scripts/train_mim.py +11 -10
- birder/scripts/train_mmcr.py +10 -7
- birder/scripts/train_rotnet.py +10 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/__main__.py +6 -2
- birder/tools/adversarial.py +147 -96
- birder/tools/auto_anchors.py +361 -0
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +58 -31
- birder/version.py +1 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/METADATA +2 -1
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/RECORD +60 -55
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/WHEEL +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/entry_points.txt +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.1.dist-info → birder-0.2.2.dist-info}/top_level.txt +0 -0
birder/scripts/train_dino_v2.py
CHANGED

@@ -214,7 +214,8 @@ def train(args: argparse.Namespace) -> None:
     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

     batch_size: int = args.batch_size
-
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     begin_epoch = 1
     epochs = args.epochs + 1

@@ -417,6 +418,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1

     #

@@ -438,20 +440,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -492,11 +493,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)
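Every training script touched in this release gets the same bookkeeping: `grad_accum_steps` is read once next to `batch_size`, one optimizer step is taken per accumulation window, and the scheduler is built with either one step per epoch or one step per optimizer step. Below is a minimal, self-contained sketch of that arithmetic; the names mirror the diff, while `args` and the batch count are made-up stand-ins, not birder's real objects.

```python
import argparse
import math

# Hypothetical values, for illustration only
args = argparse.Namespace(batch_size=64, grad_accum_steps=4, world_size=8, lr_scheduler_update="step")
num_batches_per_epoch = 1000  # stand-in for len(training_loader)

# Gradients from grad_accum_steps batches on world_size ranks feed each update
effective_batch_size = args.batch_size * args.grad_accum_steps * args.world_size  # 2048

# One optimizer step per accumulation window, rounding up for the tail window
optimizer_steps_per_epoch = math.ceil(num_batches_per_epoch / args.grad_accum_steps)  # 250

if args.lr_scheduler_update == "epoch":
    scheduler_steps_per_epoch = 1  # scheduler.step() once per epoch
elif args.lr_scheduler_update == "step":
    scheduler_steps_per_epoch = optimizer_steps_per_epoch  # scheduler.step() after every optimizer update
else:
    raise ValueError("Unsupported lr_scheduler_update")

print(effective_batch_size, optimizer_steps_per_epoch, scheduler_steps_per_epoch)  # 2048 250 250
```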
birder/scripts/train_dino_v2_dist.py
CHANGED

@@ -215,7 +215,8 @@ def train(args: argparse.Namespace) -> None:
     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

     batch_size: int = args.batch_size
-
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     begin_epoch = 1
     epochs = args.epochs + 1

@@ -240,6 +241,7 @@ def train(args: argparse.Namespace) -> None:
         args.network, sample_shape[1], 0, config=args.model_config, size=args.size
     )
     student_backbone_ema.load_state_dict(student_backbone.state_dict())
+    student_backbone_ema.requires_grad_(False)

     teacher_backbone = registry.net_factory(
         args.teacher,

@@ -248,6 +250,11 @@ def train(args: argparse.Namespace) -> None:
         config=args.teacher_model_config,
         size=args.size,
     )
+    assert student_backbone.max_stride == teacher_backbone.max_stride, (
+        "Student and teacher max_stride must match for distillation "
+        f"(student={student_backbone.max_stride}, teacher={teacher_backbone.max_stride})"
+    )
+
     student_backbone.set_dynamic_size()
     if args.ibot_separate_head is False:
         args.ibot_out_dim = args.dino_out_dim

@@ -433,6 +440,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1

     #

@@ -454,20 +462,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -507,11 +514,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)

@@ -604,6 +613,7 @@ def train(args: argparse.Namespace) -> None:
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
+        teacher.eval()
         running_loss = training_utils.SmoothedValue()
         running_loss_dino_local = training_utils.SmoothedValue()
         running_loss_dino_global = training_utils.SmoothedValue()
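The distillation script also hardens the teacher/EMA handling: the EMA copy of the student backbone is detached from autograd, the student and teacher are checked for compatible output strides, and the frozen teacher is explicitly kept in eval mode inside the epoch loop. A toy illustration of why each of those lines matters; `TinyBackbone` and its `max_stride` attribute are stand-ins, not birder classes.

```python
import torch
from torch import nn

class TinyBackbone(nn.Module):
    max_stride = 32  # output stride of the deepest feature map (made-up value)

    def __init__(self) -> None:
        super().__init__()
        self.stem = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.stem(x)

student = TinyBackbone()
student_ema = TinyBackbone()
student_ema.load_state_dict(student.state_dict())
student_ema.requires_grad_(False)  # EMA weights are written manually, never by the optimizer

teacher = TinyBackbone()
assert student.max_stride == teacher.max_stride, "feature maps must align for patch-level distillation"

teacher.eval()  # keep normalization/dropout in inference mode for the frozen teacher throughout training
```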
birder/scripts/train_franca.py
CHANGED

@@ -241,7 +241,8 @@ def train(args: argparse.Namespace) -> None:
     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

     batch_size: int = args.batch_size
-
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     begin_epoch = 1
     epochs = args.epochs + 1

@@ -444,6 +445,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1

     #

@@ -465,20 +467,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -519,11 +520,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)
birder/scripts/train_i_jepa.py
CHANGED

@@ -120,7 +120,8 @@ def train(args: argparse.Namespace) -> None:
     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

     batch_size: int = args.batch_size
-
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     begin_epoch = 1
     epochs = args.epochs + 1

@@ -284,6 +285,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1

     #

@@ -305,20 +307,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -351,11 +352,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)

@@ -521,12 +524,13 @@ def train(args: argparse.Namespace) -> None:
             if step_update is True:
                 scheduler.step()

-
-
-
-
-
-
+            if optimizer_update is True:
+                # EMA update for the target encoder
+                with torch.no_grad():
+                    m = momentum_schedule[global_iter]
+                    torch._foreach_lerp_(  # pylint: disable=protected-access
+                        list(target_encoder.parameters()), list(encoder.parameters()), weight=1 - m
+                    )

             # Statistics
             running_loss.update(loss.detach())
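The I-JEPA change moves the target-encoder EMA update under the `optimizer_update` guard, so the momentum step happens once per optimizer step rather than once per batch. A standalone sketch of that fused lerp-style update with toy modules follows; the linear momentum ramp is a placeholder, not I-JEPA's real schedule.

```python
import copy

import torch
from torch import nn

encoder = nn.Linear(8, 8)
target_encoder = copy.deepcopy(encoder).requires_grad_(False)  # updated only via EMA

total_optimizer_steps = 100
momentum_schedule = torch.linspace(0.996, 1.0, total_optimizer_steps)  # placeholder ramp

def ema_update(global_iter: int) -> None:
    m = float(momentum_schedule[global_iter])
    with torch.no_grad():
        # In-place: target = m * target + (1 - m) * online, fused across all parameters
        torch._foreach_lerp_(  # pylint: disable=protected-access
            list(target_encoder.parameters()), list(encoder.parameters()), weight=1 - m
        )

ema_update(0)  # call once per optimizer step, i.e. only when optimizer_update is True
```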
birder/scripts/train_ibot.py
CHANGED

@@ -143,7 +143,8 @@ def train(args: argparse.Namespace) -> None:
     torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

     batch_size: int = args.batch_size
-
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     begin_epoch = 1
     epochs = args.epochs + 1

@@ -351,6 +352,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1

     #

@@ -372,20 +374,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -418,11 +419,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)
birder/scripts/train_kd.py
CHANGED

@@ -186,8 +186,9 @@ def train(args: argparse.Namespace) -> None:

     num_outputs = len(class_to_idx)
     batch_size: int = args.batch_size
-
-
+    grad_accum_steps: int = args.grad_accum_steps
+    model_ema_steps: int = args.model_ema_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Set data iterators
     if args.mixup_alpha is not None or args.cutmix is True:

@@ -246,8 +247,8 @@ def train(args: argparse.Namespace) -> None:
         pin_memory=True,
     )

-    optimizer_steps_per_epoch = math.ceil(len(training_loader) /
-    assert args.model_ema is False or
+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
+    assert args.model_ema is False or model_ema_steps <= optimizer_steps_per_epoch

     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1

@@ -336,20 +337,18 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps
-
     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -375,11 +374,14 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False),
+            lrs,
+        )
         plt.show()
         raise SystemExit(0)

@@ -387,15 +389,15 @@ def train(args: argparse.Namespace) -> None:
     # Distributed (DDP) and Model EMA
     #
     if args.model_ema_warmup is not None:
-
+        ema_warmup_steps = args.model_ema_warmup * optimizer_steps_per_epoch
     elif args.warmup_epochs is not None:
-
+        ema_warmup_steps = args.warmup_epochs * optimizer_steps_per_epoch
     elif args.warmup_steps is not None:
-
+        ema_warmup_steps = args.warmup_steps
     else:
-
+        ema_warmup_steps = 0

-    logger.debug(f"EMA warmup
+    logger.debug(f"EMA warmup steps = {ema_warmup_steps}")
     net_without_ddp = student
     if args.distributed is True:
         student = torch.nn.parallel.DistributedDataParallel(

@@ -493,6 +495,7 @@ def train(args: argparse.Namespace) -> None:
     #
     # Training loop
     #
+    optimizer_step = (begin_epoch - 1) * optimizer_steps_per_epoch
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()

@@ -571,10 +574,13 @@ def train(args: argparse.Namespace) -> None:
             if step_update is True:
                 scheduler.step()

+            if optimizer_update is True:
+                optimizer_step += 1
+
             # Exponential moving average
-            if args.model_ema is True and
+            if args.model_ema is True and optimizer_update is True and optimizer_step % model_ema_steps == 0:
                 model_ema.update_parameters(student)
-                if
+                if ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps:
                     # Reset ema buffer to keep copying weights during warmup period
                     model_ema.n_averaged.fill_(0)  # pylint: disable=no-member
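In the knowledge-distillation script, the model-EMA cadence is now expressed in optimizer steps: a counter advances only when `optimizer_update` fires, the EMA model is refreshed every `model_ema_steps` optimizer steps, and while still inside `ema_warmup_steps` the averaging buffer keeps being reset so the EMA copy tracks the raw weights. A small self-check of that gating; the numbers are illustrative, not taken from a real run.

```python
def ema_actions(optimizer_step: int, model_ema_steps: int, ema_warmup_steps: int) -> tuple[bool, bool]:
    """Return (update_ema, still_in_warmup) for a given optimizer-step counter."""
    update_ema = optimizer_step % model_ema_steps == 0
    in_warmup = ema_warmup_steps > 0 and optimizer_step <= ema_warmup_steps
    return (update_ema, in_warmup)

# With model_ema_steps=32 and a warmup of 500 optimizer steps (e.g. 2 epochs of 250 steps each):
assert ema_actions(32, 32, 500) == (True, True)    # update EMA, but keep copying raw weights
assert ema_actions(33, 32, 500) == (False, True)   # skip: not on an EMA boundary
assert ema_actions(512, 32, 500) == (True, False)  # past warmup: genuine exponential averaging
```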
birder/scripts/train_mim.py
CHANGED

@@ -130,7 +130,8 @@ def train(args: argparse.Namespace) -> None:
     logger.info(f"Training on {len(training_dataset):,} samples")

     batch_size: int = args.batch_size
-
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:

@@ -162,6 +163,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1
     epochs = args.epochs + 1

@@ -254,20 +256,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -293,11 +294,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)

@@ -599,9 +602,7 @@ def get_args_parser() -> argparse.ArgumentParser:
             "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
         ),
     )
-    parser.add_argument(
-        "--mask-ratio", type=float, default=None, help="mask ratio for MIM training (default: model-specific)"
-    )
+    parser.add_argument("--mask-ratio", type=float, help="mask ratio for MIM training (default: model-specific)")
     parser.add_argument("--min-mask-size", type=int, default=1, help="minimum mask unit size in patches")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser)
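The LR-preview block shown in these diffs now fast-forwards the scheduler `scheduler_steps_per_epoch` times per epoch and plots the collected rates against an epoch-valued x-axis. Below is a runnable stand-in that uses a plain SGD optimizer with a cosine schedule instead of birder's `training_utils` helpers; the step counts are made up.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch

begin_epoch, epochs = 1, 11          # 10 epochs to preview
scheduler_steps_per_epoch = 25       # stands in for math.ceil(len(training_loader) / grad_accum_steps)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=scheduler_steps_per_epoch * (epochs - begin_epoch)
)

optimizer.step()  # avoid the "scheduler.step() before optimizer.step()" warning
lrs = []
for _ in range(begin_epoch, epochs):
    for _ in range(scheduler_steps_per_epoch):
        lrs.append(float(max(scheduler.get_last_lr())))
        scheduler.step()

# One x value per scheduler step, expressed in epoch units
plt.plot(np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs)
plt.show()
```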
birder/scripts/train_mmcr.py
CHANGED

@@ -154,7 +154,8 @@ def train(args: argparse.Namespace) -> None:
     logger.info(f"Training on {len(training_dataset):,} samples")

     batch_size: int = args.batch_size
-
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:

@@ -186,6 +187,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1
     epochs = args.epochs + 1

@@ -256,20 +258,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -295,11 +296,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
         plt.show()
         raise SystemExit(0)
birder/scripts/train_rotnet.py
CHANGED

@@ -168,7 +168,8 @@ def train(args: argparse.Namespace) -> None:
     logger.info(f"Training on {len(training_dataset):,} samples")

     batch_size: int = args.batch_size
-
+    grad_accum_steps: int = args.grad_accum_steps
+    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:

@@ -200,6 +201,7 @@ def train(args: argparse.Namespace) -> None:
         drop_last=args.drop_last,
     )

+    optimizer_steps_per_epoch = math.ceil(len(training_loader) / grad_accum_steps)
     last_batch_idx = len(training_loader) - 1
     begin_epoch = 1
     epochs = args.epochs + 1

@@ -268,20 +270,19 @@ def train(args: argparse.Namespace) -> None:

     # Learning rate scaling
     lr = training_utils.scale_lr(args)
-    grad_accum_steps: int = args.grad_accum_steps

     if args.lr_scheduler_update == "epoch":
         step_update = False
-
+        scheduler_steps_per_epoch = 1
     elif args.lr_scheduler_update == "step":
         step_update = True
-
+        scheduler_steps_per_epoch = optimizer_steps_per_epoch
     else:
         raise ValueError("Unsupported lr_scheduler_update")

     # Optimizer and learning rate scheduler
     optimizer = training_utils.get_optimizer(parameters, lr, args)
-    scheduler = training_utils.get_scheduler(optimizer,
+    scheduler = training_utils.get_scheduler(optimizer, scheduler_steps_per_epoch, args)
     if args.compile_opt is True:
         optimizer.step = torch.compile(optimizer.step, fullgraph=False)

@@ -307,11 +308,13 @@ def train(args: argparse.Namespace) -> None:
         optimizer.step()
         lrs = []
         for _ in range(begin_epoch, epochs):
-            for _ in range(
+            for _ in range(scheduler_steps_per_epoch):
                 lrs.append(float(max(scheduler.get_last_lr())))
                 scheduler.step()

-        plt.plot(
+        plt.plot(
+            np.linspace(begin_epoch, epochs, scheduler_steps_per_epoch * (epochs - begin_epoch), endpoint=False), lrs
+        )
        plt.show()
        raise SystemExit(0)