birder 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/masking.py +13 -4
  4. birder/inference/classification.py +1 -1
  5. birder/introspection/__init__.py +2 -0
  6. birder/introspection/base.py +0 -7
  7. birder/introspection/feature_pca.py +101 -0
  8. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  9. birder/model_registry/model_registry.py +3 -2
  10. birder/net/convnext_v1.py +20 -0
  11. birder/net/fastvit.py +0 -1
  12. birder/net/flexivit.py +5 -0
  13. birder/net/focalnet.py +0 -1
  14. birder/net/hiera.py +3 -3
  15. birder/net/hieradet.py +116 -28
  16. birder/net/rope_flexivit.py +7 -0
  17. birder/net/rope_vit.py +49 -4
  18. birder/net/smt.py +0 -1
  19. birder/net/ssl/ibot.py +0 -1
  20. birder/net/vit.py +166 -2
  21. birder/scripts/train.py +24 -21
  22. birder/scripts/train_barlow_twins.py +4 -3
  23. birder/scripts/train_byol.py +4 -3
  24. birder/scripts/train_capi.py +6 -5
  25. birder/scripts/train_data2vec.py +4 -3
  26. birder/scripts/train_data2vec2.py +4 -3
  27. birder/scripts/train_detection.py +7 -5
  28. birder/scripts/train_dino_v1.py +5 -4
  29. birder/scripts/train_dino_v2.py +69 -20
  30. birder/scripts/train_dino_v2_dist.py +70 -21
  31. birder/scripts/train_franca.py +8 -7
  32. birder/scripts/train_i_jepa.py +4 -3
  33. birder/scripts/train_ibot.py +5 -4
  34. birder/scripts/train_kd.py +25 -24
  35. birder/scripts/train_mim.py +4 -3
  36. birder/scripts/train_mmcr.py +4 -3
  37. birder/scripts/train_rotnet.py +5 -4
  38. birder/scripts/train_simclr.py +4 -3
  39. birder/scripts/train_vicreg.py +4 -3
  40. birder/tools/avg_model.py +24 -8
  41. birder/tools/introspection.py +35 -9
  42. birder/tools/show_iterator.py +17 -3
  43. birder/version.py +1 -1
  44. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/METADATA +1 -1
  45. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/RECORD +49 -48
  46. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/WHEEL +0 -0
  47. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/entry_points.txt +0 -0
  48. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/licenses/LICENSE +0 -0
  49. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/top_level.txt +0 -0
birder/scripts/train_kd.py CHANGED
@@ -356,7 +356,7 @@ def train(args: argparse.Namespace) -> None:
 
     # Distillation
     if distillation_type == "soft":
-        distillation_criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)
+        distillation_criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
     elif distillation_type == "hard":
         distillation_criterion = torch.nn.CrossEntropyLoss()
     elif distillation_type == "deit":
@@ -567,6 +567,11 @@ def train(args: argparse.Namespace) -> None:
     if virtual_epoch_mode is True:
         train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue(window_size=64)
+    running_val_loss = training_utils.SmoothedValue()
+    train_accuracy = training_utils.SmoothedValue(window_size=64)
+    val_accuracy = training_utils.SmoothedValue()
+
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
@@ -574,11 +579,6 @@ def train(args: argparse.Namespace) -> None:
         if embedding_projection is not None:
             embedding_projection.train()
 
-        running_loss = training_utils.SmoothedValue(window_size=64)
-        running_val_loss = training_utils.SmoothedValue()
-        train_accuracy = training_utils.SmoothedValue(window_size=64)
-        val_accuracy = training_utils.SmoothedValue()
-
         if args.distributed is True or virtual_epoch_mode is True:
             train_sampler.set_epoch(epoch)
 
@@ -625,7 +625,7 @@ def train(args: argparse.Namespace) -> None:
            with torch.no_grad():
                teacher_outputs = teacher(inputs)
                if distillation_type == "soft":
-                    teacher_targets = F.softmax(teacher_outputs / args.temperature, dim=-1)
+                    teacher_targets = F.log_softmax(teacher_outputs / args.temperature, dim=-1)
                else:
                    teacher_targets = teacher_outputs.argmax(dim=-1)
 
@@ -695,7 +695,7 @@ def train(args: argparse.Namespace) -> None:
            train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
 
            # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                time_now = time.time()
                time_cost = time_now - start_time
                iters_processed_in_interval = i - last_idx
@@ -900,43 +900,44 @@ def get_args_parser() -> argparse.ArgumentParser:
            "A typical 'soft' distillation:\n"
            "torchrun --nproc_per_node=2 train_kd.py \\\n"
            " --type soft \\\n"
-            " --temperature 1 \\\n"
            " --teacher vit_l16 \\\n"
            " --student tiny_vit_5m \\\n"
+            " --temperature 3.5 \\\n"
+            " --batch-size 32 \\\n"
            " --opt adamw \\\n"
+            " --clip-grad-norm 5 \\\n"
            " --lr 0.002 \\\n"
+            " --wd 0.01 \\\n"
+            " --norm-wd 0 \\\n"
            " --lr-scheduler cosine \\\n"
            " --lr-cosine-min 1e-7 \\\n"
-            " --batch-size 64 \\\n"
            " --warmup-epochs 5 \\\n"
-            " --wd 0.01 \\\n"
-            " --norm-wd 0 \\\n"
            " --smoothing-alpha 0.1 \\\n"
-            " --clip-grad-norm 5 \\\n"
-            " --amp \\\n"
+            " --amp --amp-dtype bfloat16 \\\n"
            " --compile \\\n"
            " --wds \\\n"
-            " --wds-class-file data/intermediate_packed/classes.txt \\\n"
-            " --wds-info data/intermediate_packed/_info.json\n"
+            " --wds-info data/intermediate_packed/_info.json \\\n"
+            " --wds-class-file data/intermediate_packed/classes.txt\n"
            "\n"
-            "DeiT style distillation:\n"
+            "DeiT-style distillation:\n"
            "torchrun --nproc_per_node=2 train_kd.py \\\n"
            " --type deit \\\n"
            " --teacher regnet_y_8g \\\n"
            " --student deit_s16 \\\n"
+            " --batch-size 64 \\\n"
            " --opt adamw \\\n"
+            " --clip-grad-norm 1 \\\n"
            " --lr 0.0005 \\\n"
-            " --lr-scheduler cosine \\\n"
-            " --warmup-epochs 5 \\\n"
-            " --epochs 300 \\\n"
            " --wd 0.05 \\\n"
            " --norm-wd 0 \\\n"
+            " --lr-scheduler cosine \\\n"
+            " --epochs 300 \\\n"
+            " --warmup-epochs 5 \\\n"
+            " --aug-level 8 \\\n"
            " --smoothing-alpha 0.1 \\\n"
            " --mixup-alpha 0.8 \\\n"
-            " --aug-level 8 \\\n"
            " --model-ema \\\n"
            " --ra-sampler --ra-reps 2 \\\n"
-            " --clip-grad-norm 1 \\\n"
            " --amp \\\n"
            " --compile\n"
        ),
@@ -944,6 +945,7 @@ def get_args_parser() -> argparse.ArgumentParser:
    )
    parser.add_argument("--type", type=str, choices=typing.get_args(DistType), help="type of distillation")
    parser.add_argument("--teacher", type=str, help="the teacher network")
+    parser.add_argument("--teacher-tag", type=str, help="teacher training log tag (loading only)")
    parser.add_argument(
        "--teacher-model-config",
        action=cli.FlexibleDictAction,
@@ -952,11 +954,11 @@ def get_args_parser() -> argparse.ArgumentParser:
            "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
        ),
    )
-    parser.add_argument("--teacher-tag", type=str, help="teacher training log tag (loading only)")
    parser.add_argument("--pts", default=False, action="store_true", help="load torchscript teacher")
    parser.add_argument("--pt2", default=False, action="store_true", help="load pt2 teacher")
    parser.add_argument("--teacher-epoch", type=int, help="load teacher weights from selected epoch")
    parser.add_argument("--student", type=str, help="the student network to train")
+    parser.add_argument("--student-tag", type=str, help="add student training logs tag")
    parser.add_argument(
        "--student-model-config",
        action=cli.FlexibleDictAction,
@@ -965,7 +967,6 @@ def get_args_parser() -> argparse.ArgumentParser:
            "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
        ),
    )
-    parser.add_argument("--student-tag", type=str, help="add student training logs tag")
    parser.add_argument(
        "--temperature",
        type=float,
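
Note on the two `soft` distillation fixes above: `KLDivLoss` with `log_target=True` expects both its input and its target as log-probabilities, so the teacher targets must come from `log_softmax` rather than `softmax`. A minimal sketch of the corrected pairing (standalone illustration, not the actual train_kd.py code; the student side and the dummy shapes here are assumptions):

```python
import torch
import torch.nn.functional as F

temperature = 3.5  # matches the updated --temperature in the usage example
criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

student_logits = torch.randn(8, 100)  # dummy batch of 8 samples, 100 classes
teacher_logits = torch.randn(8, 100)

# With log_target=True, both arguments are log-probabilities
log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
log_p_teacher = F.log_softmax(teacher_logits / temperature, dim=-1)

loss = criterion(log_p_student, log_p_teacher)
```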
birder/scripts/train_mim.py CHANGED
@@ -368,11 +368,12 @@ def train(args: argparse.Namespace) -> None:
    if virtual_epoch_mode is True:
        train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
    logger.info(f"Starting training with learning rate of {last_lr}")
    for epoch in range(begin_epoch, args.stop_epoch):
        tic = time.time()
        net.train()
-        running_loss = training_utils.SmoothedValue()
 
        if args.distributed is True or virtual_epoch_mode is True:
            train_sampler.set_epoch(epoch)
@@ -436,7 +437,7 @@ def train(args: argparse.Namespace) -> None:
            running_loss.update(loss.detach())
 
            # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                time_now = time.time()
                time_cost = time_now - start_time
                iters_processed_in_interval = i - last_idx
@@ -578,6 +579,7 @@ def get_args_parser() -> argparse.ArgumentParser:
        formatter_class=cli.ArgumentHelpFormatter,
    )
    parser.add_argument("-n", "--network", type=str, help="the neural network to use")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    parser.add_argument(
        "--model-config",
        action=cli.FlexibleDictAction,
@@ -586,7 +588,6 @@ def get_args_parser() -> argparse.ArgumentParser:
            "('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
        ),
    )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    parser.add_argument("--encoder", type=str, help="the neural network to used as encoder (network being pre-trained)")
    parser.add_argument(
        "--encoder-model-config",
birder/scripts/train_mmcr.py CHANGED
@@ -370,11 +370,12 @@ def train(args: argparse.Namespace) -> None:
    if virtual_epoch_mode is True:
        train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
    logger.info(f"Starting training with learning rate of {last_lr}")
    for epoch in range(begin_epoch, args.stop_epoch):
        tic = time.time()
        net.train()
-        running_loss = training_utils.SmoothedValue()
 
        if args.distributed is True or virtual_epoch_mode is True:
            train_sampler.set_epoch(epoch)
@@ -447,7 +448,7 @@ def train(args: argparse.Namespace) -> None:
            running_loss.update(loss.detach())
 
            # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                time_now = time.time()
                time_cost = time_now - start_time
                iters_processed_in_interval = i - last_idx
@@ -587,6 +588,7 @@ def get_args_parser() -> argparse.ArgumentParser:
        formatter_class=cli.ArgumentHelpFormatter,
    )
    parser.add_argument("-n", "--network", type=str, help="the neural network to train")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    parser.add_argument(
        "--model-config",
        action=cli.FlexibleDictAction,
@@ -606,7 +608,6 @@ def get_args_parser() -> argparse.ArgumentParser:
    parser.add_argument("--lambda-coeff", type=float, default=0.0, help="weight of local nuc")
    parser.add_argument("--n-aug", type=int, default=2, help="number of views")
    parser.add_argument("--momentum-tau", type=float, default=0.99, help="base EMA parameter for momentum update")
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    training_cli.add_optimization_args(parser)
    training_cli.add_lr_wd_args(parser)
    training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_rotnet.py CHANGED
@@ -381,12 +381,13 @@ def train(args: argparse.Namespace) -> None:
    if virtual_epoch_mode is True:
        train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue(window_size=64)
+    train_accuracy = training_utils.SmoothedValue(window_size=64)
+
    logger.info(f"Starting training with learning rate of {last_lr}")
    for epoch in range(begin_epoch, args.stop_epoch):
        tic = time.time()
        net.train()
-        running_loss = training_utils.SmoothedValue(window_size=64)
-        train_accuracy = training_utils.SmoothedValue(window_size=64)
 
        if args.distributed is True or virtual_epoch_mode is True:
            train_sampler.set_epoch(epoch)
@@ -455,7 +456,7 @@ def train(args: argparse.Namespace) -> None:
            train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
 
            # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                time_now = time.time()
                time_cost = time_now - start_time
                iters_processed_in_interval = i - last_idx
@@ -580,6 +581,7 @@ def get_args_parser() -> argparse.ArgumentParser:
        formatter_class=cli.ArgumentHelpFormatter,
    )
    parser.add_argument("-n", "--network", type=str, help="the neural network to train")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    parser.add_argument(
        "--model-config",
        action=cli.FlexibleDictAction,
@@ -594,7 +596,6 @@ def get_args_parser() -> argparse.ArgumentParser:
        default=0.75,
        help="probability of applying a non-zero rotation (90, 180, or 270 degrees)",
    )
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    parser.add_argument(
        "--freeze-body",
        default=False,
birder/scripts/train_simclr.py CHANGED
@@ -363,11 +363,12 @@ def train(args: argparse.Namespace) -> None:
    if virtual_epoch_mode is True:
        train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
    logger.info(f"Starting training with learning rate of {last_lr}")
    for epoch in range(begin_epoch, args.stop_epoch):
        tic = time.time()
        net.train()
-        running_loss = training_utils.SmoothedValue()
 
        if args.distributed is True or virtual_epoch_mode is True:
            train_sampler.set_epoch(epoch)
@@ -431,7 +432,7 @@ def train(args: argparse.Namespace) -> None:
            running_loss.update(loss.detach())
 
            # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                time_now = time.time()
                time_cost = time_now - start_time
                iters_processed_in_interval = i - last_idx
@@ -572,6 +573,7 @@ def get_args_parser() -> argparse.ArgumentParser:
        formatter_class=cli.ArgumentHelpFormatter,
    )
    parser.add_argument("-n", "--network", type=str, help="the neural network to train")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    parser.add_argument(
        "--model-config",
        action=cli.FlexibleDictAction,
@@ -582,7 +584,6 @@ def get_args_parser() -> argparse.ArgumentParser:
    )
    parser.add_argument("--projection-dim", type=int, default=128, metavar="DIM", help="projection dim")
    parser.add_argument("--temperature", type=float, default=0.1, help="loss temperature")
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    training_cli.add_optimization_args(parser)
    training_cli.add_lr_wd_args(parser)
    training_cli.add_lr_scheduler_args(parser)
birder/scripts/train_vicreg.py CHANGED
@@ -369,11 +369,12 @@ def train(args: argparse.Namespace) -> None:
    if virtual_epoch_mode is True:
        train_iter = iter(training_loader)
 
+    running_loss = training_utils.SmoothedValue()
+
    logger.info(f"Starting training with learning rate of {last_lr}")
    for epoch in range(begin_epoch, args.stop_epoch):
        tic = time.time()
        net.train()
-        running_loss = training_utils.SmoothedValue()
 
        if args.distributed is True or virtual_epoch_mode is True:
            train_sampler.set_epoch(epoch)
@@ -437,7 +438,7 @@ def train(args: argparse.Namespace) -> None:
            running_loss.update(loss.detach())
 
            # Write statistics
-            if i % args.log_interval == 0 or i == last_batch_idx:
+            if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
                time_now = time.time()
                time_cost = time_now - start_time
                iters_processed_in_interval = i - last_idx
@@ -577,6 +578,7 @@ def get_args_parser() -> argparse.ArgumentParser:
        formatter_class=cli.ArgumentHelpFormatter,
    )
    parser.add_argument("-n", "--network", type=str, help="the neural network to train")
+    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    parser.add_argument(
        "--model-config",
        action=cli.FlexibleDictAction,
@@ -589,7 +591,6 @@ def get_args_parser() -> argparse.ArgumentParser:
    parser.add_argument("--sim-coeff", type=float, default=25.0, help="invariance regularization loss coefficient")
    parser.add_argument("--std-coeff", type=float, default=25.0, help="variance regularization loss coefficient")
    parser.add_argument("--cov-coeff", type=float, default=1.0, help="covariance regularization loss coefficient")
-    parser.add_argument("-t", "--tag", type=str, help="add model tag")
    training_cli.add_optimization_args(parser)
    training_cli.add_lr_wd_args(parser)
    training_cli.add_lr_scheduler_args(parser)
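
Each training script above also gains the same logging guard: `(i % args.log_interval == 0 and i > 0)` stops the statistics block from firing on the very first batch, where `iters_processed_in_interval = i - last_idx` would be zero and any interval-rate math would be meaningless (the division-by-zero risk is an inference from the visible context lines, not shown in the diff). A quick illustration:

```python
log_interval = 20
last_batch_idx = 49

for i in range(last_batch_idx + 1):
    if (i % log_interval == 0 and i > 0) or i == last_batch_idx:
        print(f"logging at batch {i}")  # fires at 20, 40 and 49, no longer at 0
```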
birder/tools/avg_model.py CHANGED
@@ -15,12 +15,15 @@ from birder.net.base import SignatureType
 logger = logging.getLogger(__name__)
 
 
-def avg_models(network: str, tag: Optional[str], reparameterized: bool, epochs: list[int], force: bool) -> None:
+# pylint: disable=too-many-locals
+def avg_models(
+    network: str, tag: Optional[str], reparameterized: bool, epochs: list[int], accum_dtype: torch.dtype, force: bool
+) -> None:
     device = torch.device("cpu")
+    network_name = get_network_name(network, tag)
     state_list = []
     aux_data = {}
     for idx, epoch in enumerate(epochs):
-        network_name = get_network_name(network, tag)
         path = fs_ops.model_path(network_name, epoch=epoch)
         logger.info(f"Loading model from {path}...")
 
@@ -51,12 +54,18 @@ def avg_models(network: str, tag: Optional[str], reparameterized: bool, epochs:
     logger.info("Calculating averages...")
     avg_state = {}
     for state_name in state_list[0]:
-        params = torch.empty((len(state_list),) + state_list[0][state_name].size())
+        t0 = state_list[0][state_name]
+        if torch.is_floating_point(t0) is True:
+            params = torch.empty((len(state_list),) + t0.size(), dtype=accum_dtype)
 
-        for idx, state in enumerate(state_list):
-            params[idx] = state[state_name]
+            for idx, state in enumerate(state_list):
+                params[idx] = state[state_name].to(accum_dtype)
 
-        avg_state[state_name] = params.mean(axis=0)
+            avg_state[state_name] = params.mean(dim=0).to(dtype=t0.dtype)
+        else:
+            # For int/bool buffers (e.g. num_batches_tracked / relative_position_index), averaging is not meaningful
+            logger.info(f"Not averaging non-floating state entry: {state_name} (dtype={t0.dtype})")
+            avg_state[state_name] = t0
 
     net.load_state_dict(avg_state)
 
@@ -86,7 +95,7 @@ def set_parser(subparsers: Any) -> None:
        epilog=(
            "Usage examples:\n"
            "python -m birder.tools avg-model --network efficientnet_v2_m --epochs 290 295 300\n"
-            "python -m birder.tools avg-model --network shufflenet_v2_2_0 --epochs 95 100 100\n"
+            "python -m birder.tools avg-model --network shufflenet_v2_2_0 --epochs 95 100 100 --accum-dtype float64\n"
        ),
        formatter_class=cli.ArgumentHelpFormatter,
    )
@@ -98,9 +107,16 @@ def set_parser(subparsers: Any) -> None:
    subparser.add_argument(
        "-r", "--reparameterized", default=False, action="store_true", help="load reparameterized model"
    )
+    subparser.add_argument(
+        "--accum-dtype",
+        choices=["float32", "float64"],
+        default="float32",
+        help="dtype used for averaging floating tensors",
+    )
    subparser.add_argument("--force", action="store_true", help="override existing model")
    subparser.set_defaults(func=main)
 
 
 def main(args: argparse.Namespace) -> None:
-    avg_models(args.network, args.tag, args.reparameterized, args.epochs, args.force)
+    accum_dtype: torch.dtype = getattr(torch, args.accum_dtype)
+    avg_models(args.network, args.tag, args.reparameterized, args.epochs, accum_dtype, args.force)
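
A simplified, self-contained sketch of the averaging behavior introduced above (the helper name `average_states` is made up for illustration): float tensors are accumulated in `--accum-dtype` and cast back to their original dtype, while integer and bool buffers such as `num_batches_tracked` are passed through from the first checkpoint.

```python
import torch


def average_states(
    states: list[dict[str, torch.Tensor]], accum_dtype: torch.dtype = torch.float64
) -> dict[str, torch.Tensor]:
    avg_state = {}
    for name, t0 in states[0].items():
        if torch.is_floating_point(t0):
            # Accumulate in the wider dtype, then cast back
            stacked = torch.stack([s[name].to(accum_dtype) for s in states])
            avg_state[name] = stacked.mean(dim=0).to(dtype=t0.dtype)
        else:
            avg_state[name] = t0  # averaging int/bool buffers is not meaningful
    return avg_state


a = {"w": torch.ones(2, 2), "steps": torch.tensor(10)}
b = {"w": torch.zeros(2, 2), "steps": torch.tensor(12)}
print(average_states([a, b])["w"])  # all 0.5, averaged in float64
```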
birder/tools/introspection.py CHANGED
@@ -10,6 +10,7 @@ from birder.common import fs_ops
 from birder.common import lib
 from birder.data.transforms.classification import inference_preset
 from birder.introspection import AttentionRollout
+from birder.introspection import FeaturePCA
 from birder.introspection import GradCAM
 from birder.introspection import GuidedBackprop
 from birder.introspection import TransformerAttribution
@@ -23,10 +24,7 @@ def _nhwc_reshape_transform(tensor: torch.Tensor) -> torch.Tensor:
 
 
 def _show_attn_rollout(
-    args: argparse.Namespace,
-    net: BaseNet,
-    transform: Callable[..., torch.Tensor],
-    device: torch.device,
+    args: argparse.Namespace, net: BaseNet, transform: Callable[..., torch.Tensor], device: torch.device
 ) -> None:
     ar = AttentionRollout(net, device, transform, args.attn_layer_name, args.discard_ratio, args.head_fusion)
     result = ar(args.image_path)
@@ -92,6 +90,16 @@ def _show_grad_cam(
     result.show()
 
 
+def _show_feature_pca(
+    args: argparse.Namespace, net: BaseNet, transform: Callable[..., torch.Tensor], device: torch.device
+) -> None:
+    feature_pca = FeaturePCA(
+        net, device, transform, args.normalize_features, channels_last=args.channels_last, stage=args.stage
+    )
+    result = feature_pca(args.image_path)
+    result.show()
+
+
 def set_parser(subparsers: Any) -> None:
     subparser = subparsers.add_parser(
         "introspection",
@@ -102,6 +110,8 @@ def set_parser(subparsers: Any) -> None:
            "Usage examples:\n"
            "python -m birder.tools introspection --network efficientnet_v2_m -e 200 --method gradcam "
            "'data/training/European goldfinch/000300.jpeg'\n"
+            "python -m birder.tools introspection -n convnext_v2_tiny -t vicreg --method feature-pca "
+            "--normalize-features --stage stage2 data/validation/Mallard/000015.jpeg\n"
            "python -m birder.tools introspection -n resnest_50 --epoch 300 --method gradcam "
            "data/index5.jpeg --target 'Grey heron'\n"
            "python -m birder.tools introspection -n efficientnet_v2_s --method guided-backprop "
@@ -126,7 +136,7 @@ def set_parser(subparsers: Any) -> None:
    subparser.add_argument(
        "--method",
        type=str,
-        choices=["gradcam", "guided-backprop", "attn-rollout", "transformer-attribution"],
+        choices=["attn-rollout", "feature-pca", "gradcam", "guided-backprop", "transformer-attribution"],
        help="introspection method",
    )
    subparser.add_argument(
@@ -142,7 +152,21 @@ def set_parser(subparsers: Any) -> None:
        "--layer-num", type=int, default=-1, help="target layer, index for target block (gradcam only)"
    )
    subparser.add_argument(
-        "--channels-last", default=False, action="store_true", help="channels last model, like swin (gradcam only)"
+        "--channels-last",
+        default=False,
+        action="store_true",
+        help="channels last model, like swin (gradcam and feature-pca)",
+    )
+    subparser.add_argument(
+        "--normalize-features",
+        default=False,
+        action="store_true",
+        help="normalize feature vectors before PCA (feature-pca only)",
+    )
+    subparser.add_argument(
+        "--stage",
+        type=str,
+        help="stage to visualize, e.g., 'stage1', 'neck', etc. (feature-pca only, defaults to last stage)",
    )
    subparser.add_argument(
        "--attn-layer-name",
@@ -193,11 +217,13 @@ def main(args: argparse.Namespace) -> None:
 
    transform = inference_preset(args.size, model_info.rgb_stats, 1.0)
 
-    if args.method == "gradcam":
+    if args.method == "attn-rollout":
+        _show_attn_rollout(args, net, transform, device)
+    elif args.method == "feature-pca":
+        _show_feature_pca(args, net, transform, device)
+    elif args.method == "gradcam":
        _show_grad_cam(args, net, model_info.class_to_idx, transform, device)
    elif args.method == "guided-backprop":
        _show_guided_backprop(args, net, model_info.class_to_idx, transform, device)
-    elif args.method == "attn-rollout":
-        _show_attn_rollout(args, net, transform, device)
    elif args.method == "transformer-attribution":
        _show_transformer_attribution(args, net, model_info.class_to_idx, transform, device)
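
For context on the new `feature-pca` method: `birder/introspection/feature_pca.py` (added in this release, +101 lines, not shown in this diff) presumably follows the common recipe of projecting per-patch feature vectors onto their top three principal components and rendering the result as an RGB map. A generic sketch of that recipe, not birder's actual implementation:

```python
import torch


def feature_pca_rgb(features: torch.Tensor) -> torch.Tensor:
    # features: (H, W, C) spatial feature map taken from some network stage
    h, w, c = features.shape
    x = features.reshape(-1, c)
    x = x - x.mean(dim=0)
    _, _, v = torch.pca_lowrank(x, q=3)  # leading principal directions, (C, 3)
    rgb = x @ v  # (H*W, 3) projection onto the top-3 components
    rgb = rgb - rgb.min(dim=0).values
    rgb = rgb / (rgb.max(dim=0).values + 1e-8)  # scale each channel to [0, 1]
    return rgb.reshape(h, w, 3)


print(feature_pca_rgb(torch.randn(14, 14, 384)).shape)  # torch.Size([14, 14, 3])
```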
birder/tools/show_iterator.py CHANGED
@@ -140,10 +140,16 @@ def show_iterator(args: argparse.Namespace) -> None:
     mask_size = (args.size[0] // args.patch_size, args.size[1] // args.patch_size)
     mask_generator: Optional[masking.Masking]
     if args.masking == "uniform":
-        mask_generator = masking.UniformMasking(mask_size, args.mask_ratio)
+        mask_generator = masking.UniformMasking(mask_size, args.mask_ratio, min_mask_size=args.min_mask_size)
     elif args.masking == "block":
         max_patches = int(args.mask_ratio * mask_size[0] * mask_size[1])
         mask_generator = masking.BlockMasking(mask_size, 4, max_patches, 0.33, 3.33)
+    elif args.masking == "roll-block":
+        num_masking_patches = int(args.mask_ratio * mask_size[0] * mask_size[1])
+        mask_generator = masking.RollBlockMasking(mask_size, num_masking_patches=num_masking_patches)
+    elif args.masking == "inverse-roll":
+        num_masking_patches = int(args.mask_ratio * mask_size[0] * mask_size[1])
+        mask_generator = masking.InverseRollBlockMasking(mask_size, num_masking_patches=num_masking_patches)
     else:
         mask_generator = None
 
@@ -187,7 +193,7 @@ def set_parser(subparsers: Any) -> None:
            "python -m birder.tools show-iterator --mode training --size 224 --batch --wds "
            "--wds-class-file ~/Datasets/imagenet-1k-wds/classes.txt --wds-size 50000 "
            "--data-path ~/Datasets/imagenet-1k-wds/validation\n"
-            "python -m birder.tools show-iterator --mode training --batch --size 224 --aug-level 6 --masking uniform\n"
+            "python -m birder.tools show-iterator --mode training --batch --size 224 --aug-level 1 --masking uniform\n"
            "python -m birder.tools show-iterator --mode training --size 224 --batch --wds "
            "--data-path data/training_packed\n"
            "python -m birder.tools show-iterator --mode training --batch --mixup-alpha 0.8 --cutmix "
@@ -206,8 +212,16 @@ def set_parser(subparsers: Any) -> None:
    )
    subparser.add_argument("--mixup-alpha", type=float, help="mixup alpha")
    subparser.add_argument("--cutmix", default=False, action="store_true", help="enable cutmix")
-    subparser.add_argument("--masking", type=str, choices=["uniform", "block"], help="enable masking")
+    subparser.add_argument(
+        "--masking",
+        type=str,
+        choices=["uniform", "block", "roll-block", "inverse-roll"],
+        help="masking strategy to apply",
+    )
    subparser.add_argument("--mask-ratio", type=float, default=0.5, help="mask ratio")
+    subparser.add_argument(
+        "--min-mask-size", type=int, default=1, help="minimum mask unit size in patches (uniform only)"
+    )
    subparser.add_argument("--patch-size", type=int, default=16, help="mask base patch size")
    subparser.add_argument(
        "--data-path", type=str, default=str(settings.TRAINING_DATA_PATH), help="image directory path"
birder/version.py CHANGED
@@ -1 +1 @@
-__version__ = "v0.3.1"
+__version__ = "v0.3.3"
{birder-0.3.1.dist-info → birder-0.3.3.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: birder
-Version: 0.3.1
+Version: 0.3.3
 Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
 Author: Ofer Hasson
 License-Expression: Apache-2.0