birder 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- birder/adversarial/deepfool.py +2 -0
- birder/adversarial/simba.py +2 -0
- birder/common/masking.py +13 -4
- birder/inference/classification.py +1 -1
- birder/introspection/__init__.py +2 -0
- birder/introspection/base.py +0 -7
- birder/introspection/feature_pca.py +101 -0
- birder/kernels/soft_nms/soft_nms.cpp +5 -2
- birder/model_registry/model_registry.py +3 -2
- birder/net/convnext_v1.py +20 -0
- birder/net/fastvit.py +0 -1
- birder/net/flexivit.py +5 -0
- birder/net/focalnet.py +0 -1
- birder/net/hiera.py +3 -3
- birder/net/hieradet.py +116 -28
- birder/net/rope_flexivit.py +7 -0
- birder/net/rope_vit.py +49 -4
- birder/net/smt.py +0 -1
- birder/net/ssl/ibot.py +0 -1
- birder/net/vit.py +166 -2
- birder/scripts/train.py +24 -21
- birder/scripts/train_barlow_twins.py +4 -3
- birder/scripts/train_byol.py +4 -3
- birder/scripts/train_capi.py +6 -5
- birder/scripts/train_data2vec.py +4 -3
- birder/scripts/train_data2vec2.py +4 -3
- birder/scripts/train_detection.py +7 -5
- birder/scripts/train_dino_v1.py +5 -4
- birder/scripts/train_dino_v2.py +69 -20
- birder/scripts/train_dino_v2_dist.py +70 -21
- birder/scripts/train_franca.py +8 -7
- birder/scripts/train_i_jepa.py +4 -3
- birder/scripts/train_ibot.py +5 -4
- birder/scripts/train_kd.py +25 -24
- birder/scripts/train_mim.py +4 -3
- birder/scripts/train_mmcr.py +4 -3
- birder/scripts/train_rotnet.py +5 -4
- birder/scripts/train_simclr.py +4 -3
- birder/scripts/train_vicreg.py +4 -3
- birder/tools/avg_model.py +24 -8
- birder/tools/introspection.py +35 -9
- birder/tools/show_iterator.py +17 -3
- birder/version.py +1 -1
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/METADATA +1 -1
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/RECORD +49 -48
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/WHEEL +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/entry_points.txt +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/top_level.txt +0 -0
birder/scripts/train_data2vec.py
CHANGED
@@ -384,11 +384,12 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -463,7 +464,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss.update(loss.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -603,6 +604,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -617,7 +619,6 @@ def get_args_parser() -> argparse.ArgumentParser:
         default=0.999,
         help="base EMA parameter for teacher update, set a higher value with small batches",
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser)
     training_cli.add_lr_scheduler_args(parser)
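A note on the recurring SmoothedValue change above, repeated across the training scripts: the trackers move from inside the epoch loop to just before it, so the smoothing window now persists across epoch boundaries instead of restarting empty at each epoch. A minimal sketch of the behavioral difference, assuming SmoothedValue acts as a windowed running average (its internals are not part of this diff; SmoothedValueSketch below is a hypothetical stand-in):

```python
from collections import deque


class SmoothedValueSketch:
    """Hypothetical stand-in for training_utils.SmoothedValue (windowed running average)."""

    def __init__(self, window_size: int = 4) -> None:
        self.values: deque[float] = deque(maxlen=window_size)

    def update(self, value: float) -> None:
        self.values.append(value)

    @property
    def avg(self) -> float:
        return sum(self.values) / max(len(self.values), 1)


epoch_losses = [[1.0, 0.9, 0.8], [0.7, 0.6, 0.5]]

# Old placement: a fresh tracker inside every epoch, so the early averages of
# each epoch are computed from only one or two batches.
for losses in epoch_losses:
    running_loss = SmoothedValueSketch()
    for loss in losses:
        running_loss.update(loss)
    print("per-epoch tracker:", running_loss.avg)

# New placement: one tracker created before the epoch loop, so the smoothing
# window carries over epoch boundaries.
running_loss = SmoothedValueSketch()
for losses in epoch_losses:
    for loss in losses:
        running_loss.update(loss)
    print("persistent tracker:", running_loss.avg)
```
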
birder/scripts/train_data2vec2.py
CHANGED

@@ -393,11 +393,12 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -473,7 +474,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss.update(loss.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -615,6 +616,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -635,7 +637,6 @@ def get_args_parser() -> argparse.ArgumentParser:
         default=0.9998,
         help="base EMA parameter for teacher update, set a higher value with small batches",
    )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser)
     training_cli.add_lr_scheduler_args(parser)
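The logging-condition change repeated in every script adds an `i > 0` guard, dropping the log entry that previously fired on the very first batch (i == 0), where the smoothed statistics cover almost no data. A quick sketch of which iterations log before and after (log_interval and last_batch_idx values here are arbitrary examples, not defaults from these scripts):

```python
log_interval = 100
last_batch_idx = 250


def should_log_old(i: int) -> bool:
    return i % log_interval == 0 or i == last_batch_idx


def should_log_new(i: int) -> bool:
    return (i % log_interval == 0 and i > 0) or i == last_batch_idx


print([i for i in range(last_batch_idx + 1) if should_log_old(i)])  # [0, 100, 200, 250]
print([i for i in range(last_batch_idx + 1) if should_log_new(i)])  # [100, 200, 250]
```
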
birder/scripts/train_detection.py
CHANGED

@@ -538,12 +538,14 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    loss_trackers: dict[str, training_utils.SmoothedValue] = {}
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-
-        loss_trackers: dict[str, training_utils.SmoothedValue] = {}
+
         validation_metrics.reset()
 
         if args.distributed is True or virtual_epoch_mode is True:
@@ -634,7 +636,7 @@ def train(args: argparse.Namespace) -> None:
                 loss_trackers[key].update(value.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -889,6 +891,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -897,8 +900,8 @@ def get_args_parser() -> argparse.ArgumentParser:
             "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
         ),
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument("--backbone", type=str, help="the neural network to used as backbone")
+    parser.add_argument("--backbone-tag", type=str, help="backbone training log tag (loading only)")
     parser.add_argument(
         "--backbone-model-config",
         action=cli.FlexibleDictAction,
@@ -907,7 +910,6 @@ def get_args_parser() -> argparse.ArgumentParser:
             "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
         ),
     )
-    parser.add_argument("--backbone-tag", type=str, help="backbone training log tag (loading only)")
     parser.add_argument("--backbone-epoch", type=int, help="load backbone weights from selected epoch")
     parser.add_argument(
         "--backbone-pretrained",
birder/scripts/train_dino_v1.py
CHANGED
@@ -480,12 +480,13 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    train_proto_agreement = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
-        train_proto_agreement = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -581,7 +582,7 @@ def train(args: argparse.Namespace) -> None:
             train_proto_agreement.update(training_utils.accuracy(pred_teacher, pred_student))
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -733,6 +734,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -788,7 +790,6 @@ def get_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--local-crop-size", type=int, nargs="+", default=[96, 96], metavar=("H", "W"), help="local view size"
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--backbone-epoch",
         type=int,
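train_proto_agreement above tracks how often the teacher's and student's argmax prototype assignments coincide, via training_utils.accuracy(pred_teacher, pred_student). That helper's internals are not shown in this diff; a plain torch equivalent of the idea, for reference:

```python
import torch


def agreement(pred_a: torch.Tensor, pred_b: torch.Tensor) -> float:
    # Fraction of positions where the two argmax predictions match
    return (pred_a == pred_b).float().mean().item()


pred_teacher = torch.tensor([0, 1, 2, 2])
pred_student = torch.tensor([0, 1, 0, 2])
print(agreement(pred_teacher, pred_student))  # 0.75
```
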
birder/scripts/train_dino_v2.py
CHANGED
@@ -582,22 +582,26 @@ def train(args: argparse.Namespace) -> None:
     #
     # Training loop
     #
-
+    track_extended_metrics = not args.no_extended_metrics
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    running_loss_dino_local = training_utils.SmoothedValue()
+    running_loss_dino_global = training_utils.SmoothedValue()
+    running_loss_koleo = training_utils.SmoothedValue()
+    running_loss_ibot_patch = training_utils.SmoothedValue()
+    if track_extended_metrics is True:
+        train_proto_agreement = training_utils.SmoothedValue()
+        train_patch_agreement = training_utils.SmoothedValue()
+        running_target_entropy = training_utils.SmoothedValue()
+        running_dino_center_drift = training_utils.SmoothedValue()
+        running_ibot_center_drift = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
-        running_loss_dino_local = training_utils.SmoothedValue()
-        running_loss_dino_global = training_utils.SmoothedValue()
-        running_loss_koleo = training_utils.SmoothedValue()
-        running_loss_ibot_patch = training_utils.SmoothedValue()
-        if track_agreement is True:
-            train_proto_agreement = training_utils.SmoothedValue()
-            train_patch_agreement = training_utils.SmoothedValue()
 
         if args.sinkhorn_queue_size is not None:
             queue_active = epoch > args.sinkhorn_queue_warmup_epochs
@@ -662,6 +666,11 @@ def train(args: argparse.Namespace) -> None:
                 )
                 teacher_patch_tokens_raw = teacher_masked_patch_tokens_after_head
                 if args.centering == "centering":
+                    # Track centers before update for drift computation
+                    if track_extended_metrics is True:
+                        prev_dino_center = dino_loss.center.clone()
+                        prev_ibot_center = ibot_patch_loss.center.clone()
+
                     teacher_dino_softmax_centered_list = dino_loss.softmax_center_teacher(
                         teacher_embedding_after_head, teacher_temp=teacher_temp
                     ).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
@@ -788,7 +797,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss_koleo.update(loss_koleo.detach())
             running_loss_ibot_patch.update(loss_ibot_patch.detach())
 
-            if track_agreement is True:
+            if track_extended_metrics is True:
                 probs_teacher = teacher_embedding_after_head.chunk(n_global_crops)
                 probs_student = student_global_embedding_after_head.chunk(n_global_crops)
                 pred_teacher = probs_teacher[0].argmax(dim=1)
@@ -799,8 +808,27 @@ def train(args: argparse.Namespace) -> None:
                 pred_patch_student = student_global_masked_patch_tokens_after_head.argmax(dim=1)
                 train_patch_agreement.update(training_utils.accuracy(pred_patch_teacher, pred_patch_student))
 
+                with torch.no_grad():
+                    p = teacher_dino_softmax_centered_list.detach()
+                    p = p.reshape(-1, p.size(-1))  # (N, D)
+
+                    # Mean distribution over prototypes (marginal)
+                    m = p.mean(dim=0).clamp_min(1e-12)
+
+                    # Entropy of the marginal
+                    entropy = -(m * m.log()).sum()
+
+                    running_target_entropy.update(entropy.detach())
+
+                    # Compute center drift
+                    if args.centering == "centering":
+                        dino_center_drift = torch.norm(dino_loss.center - prev_dino_center, p=2).detach()
+                        ibot_center_drift = torch.norm(ibot_patch_loss.center - prev_ibot_center, p=2).detach()
+                        running_dino_center_drift.update(dino_center_drift)
+                        running_ibot_center_drift.update(ibot_center_drift)
+
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -819,9 +847,13 @@ def train(args: argparse.Namespace) -> None:
                 running_loss_dino_global.synchronize_between_processes(device)
                 running_loss_koleo.synchronize_between_processes(device)
                 running_loss_ibot_patch.synchronize_between_processes(device)
-                if track_agreement is True:
+                if track_extended_metrics is True:
                     train_proto_agreement.synchronize_between_processes(device)
                     train_patch_agreement.synchronize_between_processes(device)
+                    running_target_entropy.synchronize_between_processes(device)
+                    if args.centering == "centering":
+                        running_dino_center_drift.synchronize_between_processes(device)
+                        running_ibot_center_drift.synchronize_between_processes(device)
 
                 with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
                     log.info(
@@ -846,13 +878,19 @@ def train(args: argparse.Namespace) -> None:
                         },
                         ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
                     )
-                    if track_agreement is True:
+                    if track_extended_metrics is True:
+                        metrics = {
+                            "prototype_agreement": train_proto_agreement.avg,
+                            "patch_agreement": train_patch_agreement.avg,
+                            "target_entropy": running_target_entropy.avg,
+                        }
+                        if args.centering == "centering":
+                            metrics["dino_center_drift"] = running_dino_center_drift.avg
+                            metrics["ibot_center_drift"] = running_ibot_center_drift.avg
+
                         summary_writer.add_scalars(
                             "performance",
-                            {
-                                "prototype_agreement": train_proto_agreement.avg,
-                                "patch_agreement": train_patch_agreement.avg,
-                            },
+                            metrics,
                             ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
                         )
 
@@ -867,9 +905,17 @@ def train(args: argparse.Namespace) -> None:
         logger.info(f"[Trn] Epoch {epoch}/{epochs-1} dino_global_loss: {running_loss_dino_global.global_avg:.4f}")
         logger.info(f"[Trn] Epoch {epoch}/{epochs-1} koleo_loss: {running_loss_koleo.global_avg:.4f}")
         logger.info(f"[Trn] Epoch {epoch}/{epochs-1} ibot_patch_loss: {running_loss_ibot_patch.global_avg:.4f}")
-        if track_agreement is True:
+        if track_extended_metrics is True:
            logger.info(f"[Trn] Epoch {epoch}/{epochs-1} prototype_agreement: {train_proto_agreement.global_avg:.4f}")
            logger.info(f"[Trn] Epoch {epoch}/{epochs-1} patch_agreement: {train_patch_agreement.global_avg:.4f}")
+            logger.info(f"[Trn] Epoch {epoch}/{epochs-1} target_entropy: {running_target_entropy.global_avg:.4f}")
+            if args.centering == "centering":
+                logger.info(
+                    f"[Trn] Epoch {epoch}/{epochs-1} dino_center_drift: {running_dino_center_drift.global_avg:.4f}"
+                )
+                logger.info(
+                    f"[Trn] Epoch {epoch}/{epochs-1} ibot_center_drift: {running_ibot_center_drift.global_avg:.4f}"
+                )
 
         # Learning rate scheduler update
         if step_update is False:
@@ -976,6 +1022,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -1042,9 +1089,11 @@ def get_args_parser() -> argparse.ArgumentParser:
         help="number of initial epochs to disable Sinkhorn queueing",
     )
     parser.add_argument(
-        "--no-
+        "--no-extended-metrics",
+        default=False,
+        action="store_true",
+        help="disable extended metrics (prototype/patch agreement, target entropy, center drift)",
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser, wd_end=True)
     training_cli.add_lr_scheduler_args(parser)
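The target_entropy metric added above is the entropy of the marginal teacher assignment over prototypes: uniform prototype usage gives log(D) for D prototypes, while collapse onto a few prototypes drives the value toward 0, which makes it a cheap collapse indicator. A standalone sketch of the same computation:

```python
import torch


def marginal_entropy(p: torch.Tensor) -> torch.Tensor:
    # p: (N, D) rows of softmax probabilities over D prototypes
    m = p.mean(dim=0).clamp_min(1e-12)  # marginal distribution over prototypes
    return -(m * m.log()).sum()


uniform = torch.full((8, 4), 0.25)  # every prototype used equally
collapsed = torch.zeros(8, 4)
collapsed[:, 0] = 1.0               # everything lands on prototype 0
print(marginal_entropy(uniform))    # tensor(1.3863) == log(4)
print(marginal_entropy(collapsed))  # ~0
```
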
birder/scripts/train_dino_v2_dist.py
CHANGED

@@ -603,23 +603,27 @@ def train(args: argparse.Namespace) -> None:
     #
     # Training loop
     #
-
+    track_extended_metrics = not args.no_extended_metrics
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    running_loss_dino_local = training_utils.SmoothedValue()
+    running_loss_dino_global = training_utils.SmoothedValue()
+    running_loss_koleo = training_utils.SmoothedValue()
+    running_loss_ibot_patch = training_utils.SmoothedValue()
+    if track_extended_metrics is True:
+        train_proto_agreement = training_utils.SmoothedValue()
+        train_patch_agreement = training_utils.SmoothedValue()
+        running_target_entropy = training_utils.SmoothedValue()
+        running_dino_center_drift = training_utils.SmoothedValue()
+        running_ibot_center_drift = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
         teacher.eval()
-        running_loss = training_utils.SmoothedValue()
-        running_loss_dino_local = training_utils.SmoothedValue()
-        running_loss_dino_global = training_utils.SmoothedValue()
-        running_loss_koleo = training_utils.SmoothedValue()
-        running_loss_ibot_patch = training_utils.SmoothedValue()
-        if track_agreement is True:
-            train_proto_agreement = training_utils.SmoothedValue()
-            train_patch_agreement = training_utils.SmoothedValue()
 
         if args.sinkhorn_queue_size is not None:
             queue_active = epoch > args.sinkhorn_queue_warmup_epochs
@@ -683,6 +687,11 @@ def train(args: argparse.Namespace) -> None:
                 )
                 teacher_patch_tokens_raw = teacher_masked_patch_tokens_after_head
                 if args.centering == "centering":
+                    # Track centers before update for drift computation
+                    if track_extended_metrics is True:
+                        prev_dino_center = dino_loss.center.clone()
+                        prev_ibot_center = ibot_patch_loss.center.clone()
+
                     teacher_dino_softmax_centered_list = dino_loss.softmax_center_teacher(
                         teacher_embedding_after_head, teacher_temp=teacher_temp
                     ).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
@@ -809,7 +818,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss_koleo.update(loss_koleo.detach())
             running_loss_ibot_patch.update(loss_ibot_patch.detach())
 
-            if track_agreement is True:
+            if track_extended_metrics is True:
                 probs_teacher = teacher_embedding_after_head.chunk(n_global_crops)
                 probs_student = student_global_embedding_after_head.chunk(n_global_crops)
                 pred_teacher = probs_teacher[0].argmax(dim=1)
@@ -820,8 +829,27 @@ def train(args: argparse.Namespace) -> None:
                 pred_patch_student = student_global_masked_patch_tokens_after_head.argmax(dim=1)
                 train_patch_agreement.update(training_utils.accuracy(pred_patch_teacher, pred_patch_student))
 
+                with torch.no_grad():
+                    p = teacher_dino_softmax_centered_list.detach()
+                    p = p.reshape(-1, p.size(-1))  # (N, D)
+
+                    # Mean distribution over prototypes (marginal)
+                    m = p.mean(dim=0).clamp_min(1e-12)
+
+                    # Entropy of the marginal
+                    entropy = -(m * m.log()).sum()
+
+                    running_target_entropy.update(entropy.detach())
+
+                    # Compute center drift
+                    if args.centering == "centering":
+                        dino_center_drift = torch.norm(dino_loss.center - prev_dino_center, p=2).detach()
+                        ibot_center_drift = torch.norm(ibot_patch_loss.center - prev_ibot_center, p=2).detach()
+                        running_dino_center_drift.update(dino_center_drift)
+                        running_ibot_center_drift.update(ibot_center_drift)
+
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -840,9 +868,13 @@ def train(args: argparse.Namespace) -> None:
                 running_loss_dino_global.synchronize_between_processes(device)
                 running_loss_koleo.synchronize_between_processes(device)
                 running_loss_ibot_patch.synchronize_between_processes(device)
-                if track_agreement is True:
+                if track_extended_metrics is True:
                     train_proto_agreement.synchronize_between_processes(device)
                     train_patch_agreement.synchronize_between_processes(device)
+                    running_target_entropy.synchronize_between_processes(device)
+                    if args.centering == "centering":
+                        running_dino_center_drift.synchronize_between_processes(device)
+                        running_ibot_center_drift.synchronize_between_processes(device)
 
                 with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
                     log.info(
@@ -867,13 +899,19 @@ def train(args: argparse.Namespace) -> None:
                         },
                         ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
                     )
-                    if track_agreement is True:
+                    if track_extended_metrics is True:
+                        metrics = {
+                            "prototype_agreement": train_proto_agreement.avg,
+                            "patch_agreement": train_patch_agreement.avg,
+                            "target_entropy": running_target_entropy.avg,
+                        }
+                        if args.centering == "centering":
+                            metrics["dino_center_drift"] = running_dino_center_drift.avg
+                            metrics["ibot_center_drift"] = running_ibot_center_drift.avg
+
                         summary_writer.add_scalars(
                             "performance",
-                            {
-                                "prototype_agreement": train_proto_agreement.avg,
-                                "patch_agreement": train_patch_agreement.avg,
-                            },
+                            metrics,
                             ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
                         )
 
@@ -888,9 +926,17 @@ def train(args: argparse.Namespace) -> None:
         logger.info(f"[Trn] Epoch {epoch}/{epochs-1} dino_global_loss: {running_loss_dino_global.global_avg:.4f}")
         logger.info(f"[Trn] Epoch {epoch}/{epochs-1} koleo_loss: {running_loss_koleo.global_avg:.4f}")
         logger.info(f"[Trn] Epoch {epoch}/{epochs-1} ibot_patch_loss: {running_loss_ibot_patch.global_avg:.4f}")
-        if track_agreement is True:
+        if track_extended_metrics is True:
            logger.info(f"[Trn] Epoch {epoch}/{epochs-1} prototype_agreement: {train_proto_agreement.global_avg:.4f}")
            logger.info(f"[Trn] Epoch {epoch}/{epochs-1} patch_agreement: {train_patch_agreement.global_avg:.4f}")
+            logger.info(f"[Trn] Epoch {epoch}/{epochs-1} target_entropy: {running_target_entropy.global_avg:.4f}")
+            if args.centering == "centering":
+                logger.info(
+                    f"[Trn] Epoch {epoch}/{epochs-1} dino_center_drift: {running_dino_center_drift.global_avg:.4f}"
+                )
+                logger.info(
+                    f"[Trn] Epoch {epoch}/{epochs-1} ibot_center_drift: {running_ibot_center_drift.global_avg:.4f}"
+                )
 
         # Learning rate scheduler update
         if step_update is False:
@@ -998,6 +1044,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -1006,8 +1053,8 @@ def get_args_parser() -> argparse.ArgumentParser:
            "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
         ),
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument("--teacher", type=str, help="the neural network to use as teacher")
+    parser.add_argument("--teacher-tag", type=str, help="teacher training logs tag")
     parser.add_argument(
         "--teacher-model-config",
         action=cli.FlexibleDictAction,
@@ -1016,7 +1063,6 @@ def get_args_parser() -> argparse.ArgumentParser:
            "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
         ),
     )
-    parser.add_argument("--teacher-tag", type=str, help="teacher training logs tag")
     parser.add_argument("--teacher-epoch", type=int, metavar="N", help="load teacher weights from selected epoch")
     parser.add_argument("--dino-loss-weight", type=float, default=1.0, help="weight for the DINO loss component")
     parser.add_argument("--dino-out-dim", type=int, default=65536, help="dimensionality of the DINO head output")
@@ -1070,7 +1116,10 @@ def get_args_parser() -> argparse.ArgumentParser:
         help="number of initial epochs to disable Sinkhorn queueing",
     )
     parser.add_argument(
-        "--no-
+        "--no-extended-metrics",
+        default=False,
+        action="store_true",
+        help="disable extended metrics (prototype/patch agreement, target entropy, center drift)",
     )
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser, wd_end=True)
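The center-drift metrics compare each loss's centering buffer before and after the step's update, via an L2 norm; values trending toward 0 suggest the centering statistics have stabilized. A minimal sketch of the idea (the EMA-style update below is an assumption for illustration; the diff itself only shows the norm of the difference between snapshots):

```python
import torch

center = torch.zeros(16)      # current centering buffer
batch_mean = torch.randn(16)  # assumed per-step statistic

prev_center = center.clone()                   # snapshot before the update
center = 0.9 * center + 0.1 * batch_mean       # assumed EMA-style update

drift = torch.norm(center - prev_center, p=2)  # the quantity tracked above
print(float(drift))
```
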
birder/scripts/train_franca.py
CHANGED
@@ -612,15 +612,16 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    running_loss_dino_local = training_utils.SmoothedValue()
+    running_loss_dino_global = training_utils.SmoothedValue()
+    running_loss_koleo = training_utils.SmoothedValue()
+    running_loss_ibot_patch = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
-        running_loss_dino_local = training_utils.SmoothedValue()
-        running_loss_dino_global = training_utils.SmoothedValue()
-        running_loss_koleo = training_utils.SmoothedValue()
-        running_loss_ibot_patch = training_utils.SmoothedValue()
 
         if args.sinkhorn_queue_size is not None:
             queue_active = epoch > args.sinkhorn_queue_warmup_epochs
@@ -804,7 +805,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss_ibot_patch.update(loss_ibot_patch.detach())
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -963,6 +964,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -1024,7 +1026,6 @@ def get_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--local-crop-size", type=int, nargs="+", default=[96, 96], metavar=("H", "W"), help="local view size"
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser, wd_end=True)
     training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_i_jepa.py
CHANGED
@@ -433,11 +433,12 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -534,7 +535,7 @@ def train(args: argparse.Namespace) -> None:
             running_loss.update(loss.detach())
 
            # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -677,6 +678,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -688,7 +690,6 @@ def get_args_parser() -> argparse.ArgumentParser:
     parser.add_argument("--predictor-embed-dim", type=int, default=384, help="predictor embedding dimension")
     parser.add_argument("--predictor-num-heads", type=int, default=12, help="predictor number of heads")
     parser.add_argument("--predictor-depth", type=int, default=12, help="predictor number of layers")
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser, wd_end=True)
     training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_ibot.py
CHANGED
@@ -499,12 +499,13 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+    train_proto_agreement = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
         net.train()
-        running_loss = training_utils.SmoothedValue()
-        train_proto_agreement = training_utils.SmoothedValue()
 
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
@@ -617,7 +618,7 @@ def train(args: argparse.Namespace) -> None:
             train_proto_agreement.update(training_utils.accuracy(pred_teacher, pred_student))
 
             # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                 time_now = time.time()
                 time_cost = time_now - start_time
                 iters_processed_in_interval = i - last_idx
@@ -774,6 +775,7 @@ def get_args_parser() -> argparse.ArgumentParser:
         formatter_class=cli.ArgumentHelpFormatter,
     )
     parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     parser.add_argument(
         "--model-config",
         action=cli.FlexibleDictAction,
@@ -832,7 +834,6 @@ def get_args_parser() -> argparse.ArgumentParser:
            "try increasing this value if the loss does not decrease"
         ),
     )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
     training_cli.add_optimization_args(parser)
     training_cli.add_lr_wd_args(parser, wd_end=True)
     training_cli.add_lr_scheduler_args(parser)