birder 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/masking.py +13 -4
  4. birder/inference/classification.py +1 -1
  5. birder/introspection/__init__.py +2 -0
  6. birder/introspection/base.py +0 -7
  7. birder/introspection/feature_pca.py +101 -0
  8. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  9. birder/model_registry/model_registry.py +3 -2
  10. birder/net/convnext_v1.py +20 -0
  11. birder/net/fastvit.py +0 -1
  12. birder/net/flexivit.py +5 -0
  13. birder/net/focalnet.py +0 -1
  14. birder/net/hiera.py +3 -3
  15. birder/net/hieradet.py +116 -28
  16. birder/net/rope_flexivit.py +7 -0
  17. birder/net/rope_vit.py +49 -4
  18. birder/net/smt.py +0 -1
  19. birder/net/ssl/ibot.py +0 -1
  20. birder/net/vit.py +166 -2
  21. birder/scripts/train.py +24 -21
  22. birder/scripts/train_barlow_twins.py +4 -3
  23. birder/scripts/train_byol.py +4 -3
  24. birder/scripts/train_capi.py +6 -5
  25. birder/scripts/train_data2vec.py +4 -3
  26. birder/scripts/train_data2vec2.py +4 -3
  27. birder/scripts/train_detection.py +7 -5
  28. birder/scripts/train_dino_v1.py +5 -4
  29. birder/scripts/train_dino_v2.py +69 -20
  30. birder/scripts/train_dino_v2_dist.py +70 -21
  31. birder/scripts/train_franca.py +8 -7
  32. birder/scripts/train_i_jepa.py +4 -3
  33. birder/scripts/train_ibot.py +5 -4
  34. birder/scripts/train_kd.py +25 -24
  35. birder/scripts/train_mim.py +4 -3
  36. birder/scripts/train_mmcr.py +4 -3
  37. birder/scripts/train_rotnet.py +5 -4
  38. birder/scripts/train_simclr.py +4 -3
  39. birder/scripts/train_vicreg.py +4 -3
  40. birder/tools/avg_model.py +24 -8
  41. birder/tools/introspection.py +35 -9
  42. birder/tools/show_iterator.py +17 -3
  43. birder/version.py +1 -1
  44. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/METADATA +1 -1
  45. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/RECORD +49 -48
  46. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/WHEEL +0 -0
  47. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/entry_points.txt +0 -0
  48. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/licenses/LICENSE +0 -0
  49. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/top_level.txt +0 -0
@@ -384,11 +384,12 @@ def train(args: argparse.Namespace) -> None:
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
- running_loss = training_utils.SmoothedValue()

  if args.distributed is True or virtual_epoch_mode is True:
  train_sampler.set_epoch(epoch)

@@ -463,7 +464,7 @@ def train(args: argparse.Namespace) -> None:
  running_loss.update(loss.detach())

  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx

@@ -603,6 +604,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -617,7 +619,6 @@ def get_args_parser() -> argparse.ArgumentParser:
  default=0.999,
  help="base EMA parameter for teacher update, set a higher value with small batches",
  )
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  training_cli.add_optimization_args(parser)
  training_cli.add_lr_wd_args(parser)
  training_cli.add_lr_scheduler_args(parser)
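
Note: the same two changes repeat across the training scripts in this release. The SmoothedValue loss tracker is now created once, before the epoch loop, so its statistics persist across epochs instead of being reset every epoch, and interval logging now skips iteration 0, apparently so the first log line never covers an empty interval. A minimal sketch of the pattern, using a simplified stand-in for training_utils.SmoothedValue (illustration only, not birder's actual implementation, which also supports cross-process synchronization):

    from collections import deque


    class SimpleSmoothedValue:
        """Simplified stand-in for training_utils.SmoothedValue (illustration only)"""

        def __init__(self, window_size: int = 20) -> None:
            self.window: deque[float] = deque(maxlen=window_size)
            self.total = 0.0
            self.count = 0

        def update(self, value: float) -> None:
            self.window.append(value)
            self.total += value
            self.count += 1

        @property
        def avg(self) -> float:  # smoothed over the recent window
            return sum(self.window) / max(len(self.window), 1)

        @property
        def global_avg(self) -> float:  # average over everything seen so far
            return self.total / max(self.count, 1)


    # Dummy values standing in for args.* and the data loader
    begin_epoch, stop_epoch, log_interval = 1, 3, 2
    epoch_losses = [0.9, 0.7, 0.6, 0.5, 0.45]
    last_batch_idx = len(epoch_losses) - 1

    running_loss = SimpleSmoothedValue()  # created once, persists across epochs
    for epoch in range(begin_epoch, stop_epoch):
        for i, loss in enumerate(epoch_losses):
            running_loss.update(loss)
            # The "and i > 0" guard avoids logging an interval of zero processed iterations
            if (i % log_interval == 0 and i > 0) or i == last_batch_idx:
                print(f"epoch {epoch} iter {i}: smoothed loss {running_loss.avg:.4f}")
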
@@ -393,11 +393,12 @@ def train(args: argparse.Namespace) -> None:
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
- running_loss = training_utils.SmoothedValue()

  if args.distributed is True or virtual_epoch_mode is True:
  train_sampler.set_epoch(epoch)

@@ -473,7 +474,7 @@ def train(args: argparse.Namespace) -> None:
  running_loss.update(loss.detach())

  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx

@@ -615,6 +616,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -635,7 +637,6 @@ def get_args_parser() -> argparse.ArgumentParser:
  default=0.9998,
  help="base EMA parameter for teacher update, set a higher value with small batches",
  )
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  training_cli.add_optimization_args(parser)
  training_cli.add_lr_wd_args(parser)
  training_cli.add_lr_scheduler_args(parser)
@@ -538,12 +538,14 @@ def train(args: argparse.Namespace) -> None:
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+ loss_trackers: dict[str, training_utils.SmoothedValue] = {}
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
- running_loss = training_utils.SmoothedValue()
- loss_trackers: dict[str, training_utils.SmoothedValue] = {}
+
  validation_metrics.reset()

  if args.distributed is True or virtual_epoch_mode is True:

@@ -634,7 +636,7 @@ def train(args: argparse.Namespace) -> None:
  loss_trackers[key].update(value.detach())

  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx

@@ -889,6 +891,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -897,8 +900,8 @@ def get_args_parser() -> argparse.ArgumentParser:
  "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
  ),
  )
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument("--backbone", type=str, help="the neural network to used as backbone")
+ parser.add_argument("--backbone-tag", type=str, help="backbone training log tag (loading only)")
  parser.add_argument(
  "--backbone-model-config",
  action=cli.FlexibleDictAction,

@@ -907,7 +910,6 @@ def get_args_parser() -> argparse.ArgumentParser:
  "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
  ),
  )
- parser.add_argument("--backbone-tag", type=str, help="backbone training log tag (loading only)")
  parser.add_argument("--backbone-epoch", type=int, help="load backbone weights from selected epoch")
  parser.add_argument(
  "--backbone-pretrained",
@@ -480,12 +480,13 @@ def train(args: argparse.Namespace) -> None:
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+ train_proto_agreement = training_utils.SmoothedValue()
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
- running_loss = training_utils.SmoothedValue()
- train_proto_agreement = training_utils.SmoothedValue()

  if args.distributed is True or virtual_epoch_mode is True:
  train_sampler.set_epoch(epoch)

@@ -581,7 +582,7 @@ def train(args: argparse.Namespace) -> None:
  train_proto_agreement.update(training_utils.accuracy(pred_teacher, pred_student))

  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx

@@ -733,6 +734,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -788,7 +790,6 @@ def get_args_parser() -> argparse.ArgumentParser:
  parser.add_argument(
  "--local-crop-size", type=int, nargs="+", default=[96, 96], metavar=("H", "W"), help="local view size"
  )
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--backbone-epoch",
  type=int,
@@ -582,22 +582,26 @@ def train(args: argparse.Namespace) -> None:
  #
  # Training loop
  #
- track_agreement = not args.no_agreement_metrics
+ track_extended_metrics = not args.no_extended_metrics
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+ running_loss_dino_local = training_utils.SmoothedValue()
+ running_loss_dino_global = training_utils.SmoothedValue()
+ running_loss_koleo = training_utils.SmoothedValue()
+ running_loss_ibot_patch = training_utils.SmoothedValue()
+ if track_extended_metrics is True:
+ train_proto_agreement = training_utils.SmoothedValue()
+ train_patch_agreement = training_utils.SmoothedValue()
+ running_target_entropy = training_utils.SmoothedValue()
+ running_dino_center_drift = training_utils.SmoothedValue()
+ running_ibot_center_drift = training_utils.SmoothedValue()
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
- running_loss = training_utils.SmoothedValue()
- running_loss_dino_local = training_utils.SmoothedValue()
- running_loss_dino_global = training_utils.SmoothedValue()
- running_loss_koleo = training_utils.SmoothedValue()
- running_loss_ibot_patch = training_utils.SmoothedValue()
- if track_agreement is True:
- train_proto_agreement = training_utils.SmoothedValue()
- train_patch_agreement = training_utils.SmoothedValue()

  if args.sinkhorn_queue_size is not None:
  queue_active = epoch > args.sinkhorn_queue_warmup_epochs

@@ -662,6 +666,11 @@ def train(args: argparse.Namespace) -> None:
  )
  teacher_patch_tokens_raw = teacher_masked_patch_tokens_after_head
  if args.centering == "centering":
+ # Track centers before update for drift computation
+ if track_extended_metrics is True:
+ prev_dino_center = dino_loss.center.clone()
+ prev_ibot_center = ibot_patch_loss.center.clone()
+
  teacher_dino_softmax_centered_list = dino_loss.softmax_center_teacher(
  teacher_embedding_after_head, teacher_temp=teacher_temp
  ).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])

@@ -788,7 +797,7 @@ def train(args: argparse.Namespace) -> None:
  running_loss_koleo.update(loss_koleo.detach())
  running_loss_ibot_patch.update(loss_ibot_patch.detach())

- if track_agreement is True:
+ if track_extended_metrics is True:
  probs_teacher = teacher_embedding_after_head.chunk(n_global_crops)
  probs_student = student_global_embedding_after_head.chunk(n_global_crops)
  pred_teacher = probs_teacher[0].argmax(dim=1)

@@ -799,8 +808,27 @@ def train(args: argparse.Namespace) -> None:
  pred_patch_student = student_global_masked_patch_tokens_after_head.argmax(dim=1)
  train_patch_agreement.update(training_utils.accuracy(pred_patch_teacher, pred_patch_student))

+ with torch.no_grad():
+ p = teacher_dino_softmax_centered_list.detach()
+ p = p.reshape(-1, p.size(-1)) # (N, D)
+
+ # Mean distribution over prototypes (marginal)
+ m = p.mean(dim=0).clamp_min(1e-12)
+
+ # Entropy of the marginal
+ entropy = -(m * m.log()).sum()
+
+ running_target_entropy.update(entropy.detach())
+
+ # Compute center drift
+ if args.centering == "centering":
+ dino_center_drift = torch.norm(dino_loss.center - prev_dino_center, p=2).detach()
+ ibot_center_drift = torch.norm(ibot_patch_loss.center - prev_ibot_center, p=2).detach()
+ running_dino_center_drift.update(dino_center_drift)
+ running_ibot_center_drift.update(ibot_center_drift)
+
  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx
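
Note on the new DINOv2 diagnostics added above. prototype_agreement and patch_agreement are the fraction of samples (or masked patches) on which the teacher's and student's argmax prototype assignments coincide, as computed by training_utils.accuracy on the two argmax vectors. target_entropy is the entropy of the marginal teacher distribution over prototypes, computed from the centered teacher softmax outputs p_i; it is 0 when the teacher targets collapse onto a single prototype and reaches log K for perfectly uniform prototype usage. dino_center_drift and ibot_center_drift measure how far the respective running centers moved during the current update. In LaTeX, with m matching the variable m in the diff and c the dino_loss.center / ibot_patch_loss.center buffers:

    m = \frac{1}{N}\sum_{i=1}^{N} p_i, \qquad
    H(m) = -\sum_{k=1}^{K} m_k \log m_k \in [0, \log K], \qquad
    \Delta_{\text{center}} = \lVert c_t - c_{t-1} \rVert_2
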
@@ -819,9 +847,13 @@ def train(args: argparse.Namespace) -> None:
  running_loss_dino_global.synchronize_between_processes(device)
  running_loss_koleo.synchronize_between_processes(device)
  running_loss_ibot_patch.synchronize_between_processes(device)
- if track_agreement is True:
+ if track_extended_metrics is True:
  train_proto_agreement.synchronize_between_processes(device)
  train_patch_agreement.synchronize_between_processes(device)
+ running_target_entropy.synchronize_between_processes(device)
+ if args.centering == "centering":
+ running_dino_center_drift.synchronize_between_processes(device)
+ running_ibot_center_drift.synchronize_between_processes(device)

  with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
  log.info(

@@ -846,13 +878,19 @@ def train(args: argparse.Namespace) -> None:
  },
  ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
  )
- if track_agreement is True:
+ if track_extended_metrics is True:
+ metrics = {
+ "prototype_agreement": train_proto_agreement.avg,
+ "patch_agreement": train_patch_agreement.avg,
+ "target_entropy": running_target_entropy.avg,
+ }
+ if args.centering == "centering":
+ metrics["dino_center_drift"] = running_dino_center_drift.avg
+ metrics["ibot_center_drift"] = running_ibot_center_drift.avg
+
  summary_writer.add_scalars(
  "performance",
- {
- "prototype_agreement": train_proto_agreement.avg,
- "patch_agreement": train_patch_agreement.avg,
- },
+ metrics,
  ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
  )

@@ -867,9 +905,17 @@ def train(args: argparse.Namespace) -> None:
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} dino_global_loss: {running_loss_dino_global.global_avg:.4f}")
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} koleo_loss: {running_loss_koleo.global_avg:.4f}")
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} ibot_patch_loss: {running_loss_ibot_patch.global_avg:.4f}")
- if track_agreement is True:
+ if track_extended_metrics is True:
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} prototype_agreement: {train_proto_agreement.global_avg:.4f}")
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} patch_agreement: {train_patch_agreement.global_avg:.4f}")
+ logger.info(f"[Trn] Epoch {epoch}/{epochs-1} target_entropy: {running_target_entropy.global_avg:.4f}")
+ if args.centering == "centering":
+ logger.info(
+ f"[Trn] Epoch {epoch}/{epochs-1} dino_center_drift: {running_dino_center_drift.global_avg:.4f}"
+ )
+ logger.info(
+ f"[Trn] Epoch {epoch}/{epochs-1} ibot_center_drift: {running_ibot_center_drift.global_avg:.4f}"
+ )

  # Learning rate scheduler update
  if step_update is False:

@@ -976,6 +1022,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -1042,9 +1089,11 @@ def get_args_parser() -> argparse.ArgumentParser:
  help="number of initial epochs to disable Sinkhorn queueing",
  )
  parser.add_argument(
- "--no-agreement-metrics", default=False, action="store_true", help="disable prototype/patch agreement tracking"
+ "--no-extended-metrics",
+ default=False,
+ action="store_true",
+ help="disable extended metrics (prototype/patch agreement, target entropy, center drift)",
  )
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  training_cli.add_optimization_args(parser)
  training_cli.add_lr_wd_args(parser, wd_end=True)
  training_cli.add_lr_scheduler_args(parser)
@@ -603,23 +603,27 @@ def train(args: argparse.Namespace) -> None:
  #
  # Training loop
  #
- track_agreement = not args.no_agreement_metrics
+ track_extended_metrics = not args.no_extended_metrics
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+ running_loss_dino_local = training_utils.SmoothedValue()
+ running_loss_dino_global = training_utils.SmoothedValue()
+ running_loss_koleo = training_utils.SmoothedValue()
+ running_loss_ibot_patch = training_utils.SmoothedValue()
+ if track_extended_metrics is True:
+ train_proto_agreement = training_utils.SmoothedValue()
+ train_patch_agreement = training_utils.SmoothedValue()
+ running_target_entropy = training_utils.SmoothedValue()
+ running_dino_center_drift = training_utils.SmoothedValue()
+ running_ibot_center_drift = training_utils.SmoothedValue()
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
  teacher.eval()
- running_loss = training_utils.SmoothedValue()
- running_loss_dino_local = training_utils.SmoothedValue()
- running_loss_dino_global = training_utils.SmoothedValue()
- running_loss_koleo = training_utils.SmoothedValue()
- running_loss_ibot_patch = training_utils.SmoothedValue()
- if track_agreement is True:
- train_proto_agreement = training_utils.SmoothedValue()
- train_patch_agreement = training_utils.SmoothedValue()

  if args.sinkhorn_queue_size is not None:
  queue_active = epoch > args.sinkhorn_queue_warmup_epochs

@@ -683,6 +687,11 @@ def train(args: argparse.Namespace) -> None:
  )
  teacher_patch_tokens_raw = teacher_masked_patch_tokens_after_head
  if args.centering == "centering":
+ # Track centers before update for drift computation
+ if track_extended_metrics is True:
+ prev_dino_center = dino_loss.center.clone()
+ prev_ibot_center = ibot_patch_loss.center.clone()
+
  teacher_dino_softmax_centered_list = dino_loss.softmax_center_teacher(
  teacher_embedding_after_head, teacher_temp=teacher_temp
  ).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])

@@ -809,7 +818,7 @@ def train(args: argparse.Namespace) -> None:
  running_loss_koleo.update(loss_koleo.detach())
  running_loss_ibot_patch.update(loss_ibot_patch.detach())

- if track_agreement is True:
+ if track_extended_metrics is True:
  probs_teacher = teacher_embedding_after_head.chunk(n_global_crops)
  probs_student = student_global_embedding_after_head.chunk(n_global_crops)
  pred_teacher = probs_teacher[0].argmax(dim=1)

@@ -820,8 +829,27 @@ def train(args: argparse.Namespace) -> None:
  pred_patch_student = student_global_masked_patch_tokens_after_head.argmax(dim=1)
  train_patch_agreement.update(training_utils.accuracy(pred_patch_teacher, pred_patch_student))

+ with torch.no_grad():
+ p = teacher_dino_softmax_centered_list.detach()
+ p = p.reshape(-1, p.size(-1)) # (N, D)
+
+ # Mean distribution over prototypes (marginal)
+ m = p.mean(dim=0).clamp_min(1e-12)
+
+ # Entropy of the marginal
+ entropy = -(m * m.log()).sum()
+
+ running_target_entropy.update(entropy.detach())
+
+ # Compute center drift
+ if args.centering == "centering":
+ dino_center_drift = torch.norm(dino_loss.center - prev_dino_center, p=2).detach()
+ ibot_center_drift = torch.norm(ibot_patch_loss.center - prev_ibot_center, p=2).detach()
+ running_dino_center_drift.update(dino_center_drift)
+ running_ibot_center_drift.update(ibot_center_drift)
+
  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx

@@ -840,9 +868,13 @@ def train(args: argparse.Namespace) -> None:
  running_loss_dino_global.synchronize_between_processes(device)
  running_loss_koleo.synchronize_between_processes(device)
  running_loss_ibot_patch.synchronize_between_processes(device)
- if track_agreement is True:
+ if track_extended_metrics is True:
  train_proto_agreement.synchronize_between_processes(device)
  train_patch_agreement.synchronize_between_processes(device)
+ running_target_entropy.synchronize_between_processes(device)
+ if args.centering == "centering":
+ running_dino_center_drift.synchronize_between_processes(device)
+ running_ibot_center_drift.synchronize_between_processes(device)

  with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
  log.info(

@@ -867,13 +899,19 @@ def train(args: argparse.Namespace) -> None:
  },
  ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
  )
- if track_agreement is True:
+ if track_extended_metrics is True:
+ metrics = {
+ "prototype_agreement": train_proto_agreement.avg,
+ "patch_agreement": train_patch_agreement.avg,
+ "target_entropy": running_target_entropy.avg,
+ }
+ if args.centering == "centering":
+ metrics["dino_center_drift"] = running_dino_center_drift.avg
+ metrics["ibot_center_drift"] = running_ibot_center_drift.avg
+
  summary_writer.add_scalars(
  "performance",
- {
- "prototype_agreement": train_proto_agreement.avg,
- "patch_agreement": train_patch_agreement.avg,
- },
+ metrics,
  ((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
  )

@@ -888,9 +926,17 @@ def train(args: argparse.Namespace) -> None:
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} dino_global_loss: {running_loss_dino_global.global_avg:.4f}")
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} koleo_loss: {running_loss_koleo.global_avg:.4f}")
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} ibot_patch_loss: {running_loss_ibot_patch.global_avg:.4f}")
- if track_agreement is True:
+ if track_extended_metrics is True:
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} prototype_agreement: {train_proto_agreement.global_avg:.4f}")
  logger.info(f"[Trn] Epoch {epoch}/{epochs-1} patch_agreement: {train_patch_agreement.global_avg:.4f}")
+ logger.info(f"[Trn] Epoch {epoch}/{epochs-1} target_entropy: {running_target_entropy.global_avg:.4f}")
+ if args.centering == "centering":
+ logger.info(
+ f"[Trn] Epoch {epoch}/{epochs-1} dino_center_drift: {running_dino_center_drift.global_avg:.4f}"
+ )
+ logger.info(
+ f"[Trn] Epoch {epoch}/{epochs-1} ibot_center_drift: {running_ibot_center_drift.global_avg:.4f}"
+ )

  # Learning rate scheduler update
  if step_update is False:

@@ -998,6 +1044,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -1006,8 +1053,8 @@ def get_args_parser() -> argparse.ArgumentParser:
  "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
  ),
  )
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument("--teacher", type=str, help="the neural network to use as teacher")
+ parser.add_argument("--teacher-tag", type=str, help="teacher training logs tag")
  parser.add_argument(
  "--teacher-model-config",
  action=cli.FlexibleDictAction,

@@ -1016,7 +1063,6 @@ def get_args_parser() -> argparse.ArgumentParser:
  "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
  ),
  )
- parser.add_argument("--teacher-tag", type=str, help="teacher training logs tag")
  parser.add_argument("--teacher-epoch", type=int, metavar="N", help="load teacher weights from selected epoch")
  parser.add_argument("--dino-loss-weight", type=float, default=1.0, help="weight for the DINO loss component")
  parser.add_argument("--dino-out-dim", type=int, default=65536, help="dimensionality of the DINO head output")

@@ -1070,7 +1116,10 @@ def get_args_parser() -> argparse.ArgumentParser:
  help="number of initial epochs to disable Sinkhorn queueing",
  )
  parser.add_argument(
- "--no-agreement-metrics", default=False, action="store_true", help="disable prototype/patch agreement tracking"
+ "--no-extended-metrics",
+ default=False,
+ action="store_true",
+ help="disable extended metrics (prototype/patch agreement, target entropy, center drift)",
  )
  training_cli.add_optimization_args(parser)
  training_cli.add_lr_wd_args(parser, wd_end=True)
@@ -612,15 +612,16 @@ def train(args: argparse.Namespace) -> None:
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+ running_loss_dino_local = training_utils.SmoothedValue()
+ running_loss_dino_global = training_utils.SmoothedValue()
+ running_loss_koleo = training_utils.SmoothedValue()
+ running_loss_ibot_patch = training_utils.SmoothedValue()
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
- running_loss = training_utils.SmoothedValue()
- running_loss_dino_local = training_utils.SmoothedValue()
- running_loss_dino_global = training_utils.SmoothedValue()
- running_loss_koleo = training_utils.SmoothedValue()
- running_loss_ibot_patch = training_utils.SmoothedValue()

  if args.sinkhorn_queue_size is not None:
  queue_active = epoch > args.sinkhorn_queue_warmup_epochs

@@ -804,7 +805,7 @@ def train(args: argparse.Namespace) -> None:
  running_loss_ibot_patch.update(loss_ibot_patch.detach())

  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx

@@ -963,6 +964,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -1024,7 +1026,6 @@ def get_args_parser() -> argparse.ArgumentParser:
  parser.add_argument(
  "--local-crop-size", type=int, nargs="+", default=[96, 96], metavar=("H", "W"), help="local view size"
  )
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  training_cli.add_optimization_args(parser)
  training_cli.add_lr_wd_args(parser, wd_end=True)
  training_cli.add_lr_scheduler_args(parser)
@@ -433,11 +433,12 @@ def train(args: argparse.Namespace) -> None:
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
- running_loss = training_utils.SmoothedValue()

  if args.distributed is True or virtual_epoch_mode is True:
  train_sampler.set_epoch(epoch)

@@ -534,7 +535,7 @@ def train(args: argparse.Namespace) -> None:
  running_loss.update(loss.detach())

  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx

@@ -677,6 +678,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -688,7 +690,6 @@ def get_args_parser() -> argparse.ArgumentParser:
  parser.add_argument("--predictor-embed-dim", type=int, default=384, help="predictor embedding dimension")
  parser.add_argument("--predictor-num-heads", type=int, default=12, help="predictor number of heads")
  parser.add_argument("--predictor-depth", type=int, default=12, help="predictor number of layers")
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  training_cli.add_optimization_args(parser)
  training_cli.add_lr_wd_args(parser, wd_end=True)
  training_cli.add_lr_scheduler_args(parser)

@@ -499,12 +499,13 @@ def train(args: argparse.Namespace) -> None:
  if virtual_epoch_mode is True:
  train_iter = iter(training_loader)

+ running_loss = training_utils.SmoothedValue()
+ train_proto_agreement = training_utils.SmoothedValue()
+
  logger.info(f"Starting training with learning rate of {last_lr}")
  for epoch in range(begin_epoch, args.stop_epoch):
  tic = time.time()
  net.train()
- running_loss = training_utils.SmoothedValue()
- train_proto_agreement = training_utils.SmoothedValue()

  if args.distributed is True or virtual_epoch_mode is True:
  train_sampler.set_epoch(epoch)

@@ -617,7 +618,7 @@ def train(args: argparse.Namespace) -> None:
  train_proto_agreement.update(training_utils.accuracy(pred_teacher, pred_student))

  # Write statistics
- if i % args.log_interval == 0 or i == last_batch_idx:
+ if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
  time_now = time.time()
  time_cost = time_now - start_time
  iters_processed_in_interval = i - last_idx

@@ -774,6 +775,7 @@ def get_args_parser() -> argparse.ArgumentParser:
  formatter_class=cli.ArgumentHelpFormatter,
  )
  parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+ parser.add_argument("-t", "--tag", type=str, help="add model tag")
  parser.add_argument(
  "--model-config",
  action=cli.FlexibleDictAction,

@@ -832,7 +834,6 @@ def get_args_parser() -> argparse.ArgumentParser:
  "try increasing this value if the loss does not decrease"
  ),
  )
- parser.add_argument("-t", "--tag", type=str, help="add model tag")
  training_cli.add_optimization_args(parser)
  training_cli.add_lr_wd_args(parser, wd_end=True)
  training_cli.add_lr_scheduler_args(parser)