birder 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/deepfool.py +2 -0
- birder/adversarial/simba.py +2 -0
- birder/common/masking.py +13 -4
- birder/inference/classification.py +1 -1
- birder/introspection/__init__.py +2 -0
- birder/introspection/base.py +0 -7
- birder/introspection/feature_pca.py +101 -0
- birder/kernels/soft_nms/soft_nms.cpp +5 -2
- birder/model_registry/model_registry.py +3 -2
- birder/net/convnext_v1.py +20 -0
- birder/net/fastvit.py +0 -1
- birder/net/flexivit.py +5 -0
- birder/net/focalnet.py +0 -1
- birder/net/hiera.py +3 -3
- birder/net/hieradet.py +116 -28
- birder/net/rope_flexivit.py +7 -0
- birder/net/rope_vit.py +49 -4
- birder/net/smt.py +0 -1
- birder/net/ssl/ibot.py +0 -1
- birder/net/vit.py +166 -2
- birder/scripts/train.py +24 -21
- birder/scripts/train_barlow_twins.py +4 -3
- birder/scripts/train_byol.py +4 -3
- birder/scripts/train_capi.py +6 -5
- birder/scripts/train_data2vec.py +4 -3
- birder/scripts/train_data2vec2.py +4 -3
- birder/scripts/train_detection.py +7 -5
- birder/scripts/train_dino_v1.py +5 -4
- birder/scripts/train_dino_v2.py +69 -20
- birder/scripts/train_dino_v2_dist.py +70 -21
- birder/scripts/train_franca.py +8 -7
- birder/scripts/train_i_jepa.py +4 -3
- birder/scripts/train_ibot.py +5 -4
- birder/scripts/train_kd.py +25 -24
- birder/scripts/train_mim.py +4 -3
- birder/scripts/train_mmcr.py +4 -3
- birder/scripts/train_rotnet.py +5 -4
- birder/scripts/train_simclr.py +4 -3
- birder/scripts/train_vicreg.py +4 -3
- birder/tools/avg_model.py +24 -8
- birder/tools/introspection.py +35 -9
- birder/tools/show_iterator.py +17 -3
- birder/version.py +1 -1
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/METADATA +1 -1
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/RECORD +49 -48
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/WHEEL +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/entry_points.txt +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/top_level.txt +0 -0
birder/scripts/train_kd.py
CHANGED
|
@@ -356,7 +356,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
356
356
|
|
|
357
357
|
# Distillation
|
|
358
358
|
if distillation_type == "soft":
|
|
359
|
-
distillation_criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)
|
|
359
|
+
distillation_criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
|
|
360
360
|
elif distillation_type == "hard":
|
|
361
361
|
distillation_criterion = torch.nn.CrossEntropyLoss()
|
|
362
362
|
elif distillation_type == "deit":
|
|
@@ -567,6 +567,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
567
567
|
if virtual_epoch_mode is True:
|
|
568
568
|
train_iter = iter(training_loader)
|
|
569
569
|
|
|
570
|
+
running_loss = training_utils.SmoothedValue(window_size=64)
|
|
571
|
+
running_val_loss = training_utils.SmoothedValue()
|
|
572
|
+
train_accuracy = training_utils.SmoothedValue(window_size=64)
|
|
573
|
+
val_accuracy = training_utils.SmoothedValue()
|
|
574
|
+
|
|
570
575
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
571
576
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
572
577
|
tic = time.time()
|
|
@@ -574,11 +579,6 @@ def train(args: argparse.Namespace) -> None:
|
|
|
574
579
|
if embedding_projection is not None:
|
|
575
580
|
embedding_projection.train()
|
|
576
581
|
|
|
577
|
-
running_loss = training_utils.SmoothedValue(window_size=64)
|
|
578
|
-
running_val_loss = training_utils.SmoothedValue()
|
|
579
|
-
train_accuracy = training_utils.SmoothedValue(window_size=64)
|
|
580
|
-
val_accuracy = training_utils.SmoothedValue()
|
|
581
|
-
|
|
582
582
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
583
583
|
train_sampler.set_epoch(epoch)
|
|
584
584
|
|
|
@@ -625,7 +625,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
625
625
|
with torch.no_grad():
|
|
626
626
|
teacher_outputs = teacher(inputs)
|
|
627
627
|
if distillation_type == "soft":
|
|
628
|
-
teacher_targets = F.
|
|
628
|
+
teacher_targets = F.log_softmax(teacher_outputs / args.temperature, dim=-1)
|
|
629
629
|
else:
|
|
630
630
|
teacher_targets = teacher_outputs.argmax(dim=-1)
|
|
631
631
|
|
|
@@ -695,7 +695,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
695
695
|
train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
|
|
696
696
|
|
|
697
697
|
# Write statistics
|
|
698
|
-
if i % args.log_interval == 0 or i == last_batch_idx:
|
|
698
|
+
if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
|
|
699
699
|
time_now = time.time()
|
|
700
700
|
time_cost = time_now - start_time
|
|
701
701
|
iters_processed_in_interval = i - last_idx
|
|
@@ -900,43 +900,44 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
900
900
|
"A typical 'soft' distillation:\n"
|
|
901
901
|
"torchrun --nproc_per_node=2 train_kd.py \\\n"
|
|
902
902
|
" --type soft \\\n"
|
|
903
|
-
" --temperature 1 \\\n"
|
|
904
903
|
" --teacher vit_l16 \\\n"
|
|
905
904
|
" --student tiny_vit_5m \\\n"
|
|
905
|
+
" --temperature 3.5 \\\n"
|
|
906
|
+
" --batch-size 32 \\\n"
|
|
906
907
|
" --opt adamw \\\n"
|
|
908
|
+
" --clip-grad-norm 5 \\\n"
|
|
907
909
|
" --lr 0.002 \\\n"
|
|
910
|
+
" --wd 0.01 \\\n"
|
|
911
|
+
" --norm-wd 0 \\\n"
|
|
908
912
|
" --lr-scheduler cosine \\\n"
|
|
909
913
|
" --lr-cosine-min 1e-7 \\\n"
|
|
910
|
-
" --batch-size 64 \\\n"
|
|
911
914
|
" --warmup-epochs 5 \\\n"
|
|
912
|
-
" --wd 0.01 \\\n"
|
|
913
|
-
" --norm-wd 0 \\\n"
|
|
914
915
|
" --smoothing-alpha 0.1 \\\n"
|
|
915
|
-
" --
|
|
916
|
-
" --amp \\\n"
|
|
916
|
+
" --amp --amp-dtype bfloat16 \\\n"
|
|
917
917
|
" --compile \\\n"
|
|
918
918
|
" --wds \\\n"
|
|
919
|
-
" --wds-
|
|
920
|
-
" --wds-
|
|
919
|
+
" --wds-info data/intermediate_packed/_info.json \\\n"
|
|
920
|
+
" --wds-class-file data/intermediate_packed/classes.txt\n"
|
|
921
921
|
"\n"
|
|
922
|
-
"DeiT
|
|
922
|
+
"DeiT-style distillation:\n"
|
|
923
923
|
"torchrun --nproc_per_node=2 train_kd.py \\\n"
|
|
924
924
|
" --type deit \\\n"
|
|
925
925
|
" --teacher regnet_y_8g \\\n"
|
|
926
926
|
" --student deit_s16 \\\n"
|
|
927
|
+
" --batch-size 64 \\\n"
|
|
927
928
|
" --opt adamw \\\n"
|
|
929
|
+
" --clip-grad-norm 1 \\\n"
|
|
928
930
|
" --lr 0.0005 \\\n"
|
|
929
|
-
" --lr-scheduler cosine \\\n"
|
|
930
|
-
" --warmup-epochs 5 \\\n"
|
|
931
|
-
" --epochs 300 \\\n"
|
|
932
931
|
" --wd 0.05 \\\n"
|
|
933
932
|
" --norm-wd 0 \\\n"
|
|
933
|
+
" --lr-scheduler cosine \\\n"
|
|
934
|
+
" --epochs 300 \\\n"
|
|
935
|
+
" --warmup-epochs 5 \\\n"
|
|
936
|
+
" --aug-level 8 \\\n"
|
|
934
937
|
" --smoothing-alpha 0.1 \\\n"
|
|
935
938
|
" --mixup-alpha 0.8 \\\n"
|
|
936
|
-
" --aug-level 8 \\\n"
|
|
937
939
|
" --model-ema \\\n"
|
|
938
940
|
" --ra-sampler --ra-reps 2 \\\n"
|
|
939
|
-
" --clip-grad-norm 1 \\\n"
|
|
940
941
|
" --amp \\\n"
|
|
941
942
|
" --compile\n"
|
|
942
943
|
),
|
|
@@ -944,6 +945,7 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
944
945
|
)
|
|
945
946
|
parser.add_argument("--type", type=str, choices=typing.get_args(DistType), help="type of distillation")
|
|
946
947
|
parser.add_argument("--teacher", type=str, help="the teacher network")
|
|
948
|
+
parser.add_argument("--teacher-tag", type=str, help="teacher training log tag (loading only)")
|
|
947
949
|
parser.add_argument(
|
|
948
950
|
"--teacher-model-config",
|
|
949
951
|
action=cli.FlexibleDictAction,
|
|
@@ -952,11 +954,11 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
952
954
|
"('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
|
|
953
955
|
),
|
|
954
956
|
)
|
|
955
|
-
parser.add_argument("--teacher-tag", type=str, help="teacher training log tag (loading only)")
|
|
956
957
|
parser.add_argument("--pts", default=False, action="store_true", help="load torchscript teacher")
|
|
957
958
|
parser.add_argument("--pt2", default=False, action="store_true", help="load pt2 teacher")
|
|
958
959
|
parser.add_argument("--teacher-epoch", type=int, help="load teacher weights from selected epoch")
|
|
959
960
|
parser.add_argument("--student", type=str, help="the student network to train")
|
|
961
|
+
parser.add_argument("--student-tag", type=str, help="add student training logs tag")
|
|
960
962
|
parser.add_argument(
|
|
961
963
|
"--student-model-config",
|
|
962
964
|
action=cli.FlexibleDictAction,
|
|
@@ -965,7 +967,6 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
965
967
|
"('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
|
|
966
968
|
),
|
|
967
969
|
)
|
|
968
|
-
parser.add_argument("--student-tag", type=str, help="add student training logs tag")
|
|
969
970
|
parser.add_argument(
|
|
970
971
|
"--temperature",
|
|
971
972
|
type=float,
|
birder/scripts/train_mim.py
CHANGED
|
@@ -368,11 +368,12 @@ def train(args: argparse.Namespace) -> None:
|
|
|
368
368
|
if virtual_epoch_mode is True:
|
|
369
369
|
train_iter = iter(training_loader)
|
|
370
370
|
|
|
371
|
+
running_loss = training_utils.SmoothedValue()
|
|
372
|
+
|
|
371
373
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
372
374
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
373
375
|
tic = time.time()
|
|
374
376
|
net.train()
|
|
375
|
-
running_loss = training_utils.SmoothedValue()
|
|
376
377
|
|
|
377
378
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
378
379
|
train_sampler.set_epoch(epoch)
|
|
@@ -436,7 +437,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
436
437
|
running_loss.update(loss.detach())
|
|
437
438
|
|
|
438
439
|
# Write statistics
|
|
439
|
-
if i % args.log_interval == 0 or i == last_batch_idx:
|
|
440
|
+
if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
|
|
440
441
|
time_now = time.time()
|
|
441
442
|
time_cost = time_now - start_time
|
|
442
443
|
iters_processed_in_interval = i - last_idx
|
|
@@ -578,6 +579,7 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
578
579
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
579
580
|
)
|
|
580
581
|
parser.add_argument("-n", "--network", type=str, help="the neural network to use")
|
|
582
|
+
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
581
583
|
parser.add_argument(
|
|
582
584
|
"--model-config",
|
|
583
585
|
action=cli.FlexibleDictAction,
|
|
@@ -586,7 +588,6 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
586
588
|
"('drop_path_rate=0.2' or '{\"units\": [3, 24, 36, 3], \"dropout\": 0.2}'"
|
|
587
589
|
),
|
|
588
590
|
)
|
|
589
|
-
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
590
591
|
parser.add_argument("--encoder", type=str, help="the neural network to used as encoder (network being pre-trained)")
|
|
591
592
|
parser.add_argument(
|
|
592
593
|
"--encoder-model-config",
|
birder/scripts/train_mmcr.py
CHANGED
|
@@ -370,11 +370,12 @@ def train(args: argparse.Namespace) -> None:
|
|
|
370
370
|
if virtual_epoch_mode is True:
|
|
371
371
|
train_iter = iter(training_loader)
|
|
372
372
|
|
|
373
|
+
running_loss = training_utils.SmoothedValue()
|
|
374
|
+
|
|
373
375
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
374
376
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
375
377
|
tic = time.time()
|
|
376
378
|
net.train()
|
|
377
|
-
running_loss = training_utils.SmoothedValue()
|
|
378
379
|
|
|
379
380
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
380
381
|
train_sampler.set_epoch(epoch)
|
|
@@ -447,7 +448,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
447
448
|
running_loss.update(loss.detach())
|
|
448
449
|
|
|
449
450
|
# Write statistics
|
|
450
|
-
if i % args.log_interval == 0 or i == last_batch_idx:
|
|
451
|
+
if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
|
|
451
452
|
time_now = time.time()
|
|
452
453
|
time_cost = time_now - start_time
|
|
453
454
|
iters_processed_in_interval = i - last_idx
|
|
@@ -587,6 +588,7 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
587
588
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
588
589
|
)
|
|
589
590
|
parser.add_argument("-n", "--network", type=str, help="the neural network to train")
|
|
591
|
+
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
590
592
|
parser.add_argument(
|
|
591
593
|
"--model-config",
|
|
592
594
|
action=cli.FlexibleDictAction,
|
|
@@ -606,7 +608,6 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
606
608
|
parser.add_argument("--lambda-coeff", type=float, default=0.0, help="weight of local nuc")
|
|
607
609
|
parser.add_argument("--n-aug", type=int, default=2, help="number of views")
|
|
608
610
|
parser.add_argument("--momentum-tau", type=float, default=0.99, help="base EMA parameter for momentum update")
|
|
609
|
-
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
610
611
|
training_cli.add_optimization_args(parser)
|
|
611
612
|
training_cli.add_lr_wd_args(parser)
|
|
612
613
|
training_cli.add_lr_scheduler_args(parser)
|
birder/scripts/train_rotnet.py
CHANGED
|
@@ -381,12 +381,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
381
381
|
if virtual_epoch_mode is True:
|
|
382
382
|
train_iter = iter(training_loader)
|
|
383
383
|
|
|
384
|
+
running_loss = training_utils.SmoothedValue(window_size=64)
|
|
385
|
+
train_accuracy = training_utils.SmoothedValue(window_size=64)
|
|
386
|
+
|
|
384
387
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
385
388
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
386
389
|
tic = time.time()
|
|
387
390
|
net.train()
|
|
388
|
-
running_loss = training_utils.SmoothedValue(window_size=64)
|
|
389
|
-
train_accuracy = training_utils.SmoothedValue(window_size=64)
|
|
390
391
|
|
|
391
392
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
392
393
|
train_sampler.set_epoch(epoch)
|
|
@@ -455,7 +456,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
455
456
|
train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
|
|
456
457
|
|
|
457
458
|
# Write statistics
|
|
458
|
-
if i % args.log_interval == 0 or i == last_batch_idx:
|
|
459
|
+
if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
|
|
459
460
|
time_now = time.time()
|
|
460
461
|
time_cost = time_now - start_time
|
|
461
462
|
iters_processed_in_interval = i - last_idx
|
|
@@ -580,6 +581,7 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
580
581
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
581
582
|
)
|
|
582
583
|
parser.add_argument("-n", "--network", type=str, help="the neural network to train")
|
|
584
|
+
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
583
585
|
parser.add_argument(
|
|
584
586
|
"--model-config",
|
|
585
587
|
action=cli.FlexibleDictAction,
|
|
@@ -594,7 +596,6 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
594
596
|
default=0.75,
|
|
595
597
|
help="probability of applying a non-zero rotation (90, 180, or 270 degrees)",
|
|
596
598
|
)
|
|
597
|
-
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
598
599
|
parser.add_argument(
|
|
599
600
|
"--freeze-body",
|
|
600
601
|
default=False,
|
birder/scripts/train_simclr.py
CHANGED
|
@@ -363,11 +363,12 @@ def train(args: argparse.Namespace) -> None:
|
|
|
363
363
|
if virtual_epoch_mode is True:
|
|
364
364
|
train_iter = iter(training_loader)
|
|
365
365
|
|
|
366
|
+
running_loss = training_utils.SmoothedValue()
|
|
367
|
+
|
|
366
368
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
367
369
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
368
370
|
tic = time.time()
|
|
369
371
|
net.train()
|
|
370
|
-
running_loss = training_utils.SmoothedValue()
|
|
371
372
|
|
|
372
373
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
373
374
|
train_sampler.set_epoch(epoch)
|
|
@@ -431,7 +432,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
431
432
|
running_loss.update(loss.detach())
|
|
432
433
|
|
|
433
434
|
# Write statistics
|
|
434
|
-
if i % args.log_interval == 0 or i == last_batch_idx:
|
|
435
|
+
if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
|
|
435
436
|
time_now = time.time()
|
|
436
437
|
time_cost = time_now - start_time
|
|
437
438
|
iters_processed_in_interval = i - last_idx
|
|
@@ -572,6 +573,7 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
572
573
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
573
574
|
)
|
|
574
575
|
parser.add_argument("-n", "--network", type=str, help="the neural network to train")
|
|
576
|
+
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
575
577
|
parser.add_argument(
|
|
576
578
|
"--model-config",
|
|
577
579
|
action=cli.FlexibleDictAction,
|
|
@@ -582,7 +584,6 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
582
584
|
)
|
|
583
585
|
parser.add_argument("--projection-dim", type=int, default=128, metavar="DIM", help="projection dim")
|
|
584
586
|
parser.add_argument("--temperature", type=float, default=0.1, help="loss temperature")
|
|
585
|
-
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
586
587
|
training_cli.add_optimization_args(parser)
|
|
587
588
|
training_cli.add_lr_wd_args(parser)
|
|
588
589
|
training_cli.add_lr_scheduler_args(parser)
|
birder/scripts/train_vicreg.py
CHANGED
|
@@ -369,11 +369,12 @@ def train(args: argparse.Namespace) -> None:
|
|
|
369
369
|
if virtual_epoch_mode is True:
|
|
370
370
|
train_iter = iter(training_loader)
|
|
371
371
|
|
|
372
|
+
running_loss = training_utils.SmoothedValue()
|
|
373
|
+
|
|
372
374
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
373
375
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
374
376
|
tic = time.time()
|
|
375
377
|
net.train()
|
|
376
|
-
running_loss = training_utils.SmoothedValue()
|
|
377
378
|
|
|
378
379
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
379
380
|
train_sampler.set_epoch(epoch)
|
|
@@ -437,7 +438,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
437
438
|
running_loss.update(loss.detach())
|
|
438
439
|
|
|
439
440
|
# Write statistics
|
|
440
|
-
if i % args.log_interval == 0 or i == last_batch_idx:
|
|
441
|
+
if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
|
|
441
442
|
time_now = time.time()
|
|
442
443
|
time_cost = time_now - start_time
|
|
443
444
|
iters_processed_in_interval = i - last_idx
|
|
@@ -577,6 +578,7 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
577
578
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
578
579
|
)
|
|
579
580
|
parser.add_argument("-n", "--network", type=str, help="the neural network to train")
|
|
581
|
+
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
580
582
|
parser.add_argument(
|
|
581
583
|
"--model-config",
|
|
582
584
|
action=cli.FlexibleDictAction,
|
|
@@ -589,7 +591,6 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
589
591
|
parser.add_argument("--sim-coeff", type=float, default=25.0, help="invariance regularization loss coefficient")
|
|
590
592
|
parser.add_argument("--std-coeff", type=float, default=25.0, help="variance regularization loss coefficient")
|
|
591
593
|
parser.add_argument("--cov-coeff", type=float, default=1.0, help="covariance regularization loss coefficient")
|
|
592
|
-
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
593
594
|
training_cli.add_optimization_args(parser)
|
|
594
595
|
training_cli.add_lr_wd_args(parser)
|
|
595
596
|
training_cli.add_lr_scheduler_args(parser)
|
birder/tools/avg_model.py
CHANGED
|
@@ -15,12 +15,15 @@ from birder.net.base import SignatureType
|
|
|
15
15
|
logger = logging.getLogger(__name__)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
|
|
18
|
+
# pylint: disable=too-many-locals
|
|
19
|
+
def avg_models(
|
|
20
|
+
network: str, tag: Optional[str], reparameterized: bool, epochs: list[int], accum_dtype: torch.dtype, force: bool
|
|
21
|
+
) -> None:
|
|
19
22
|
device = torch.device("cpu")
|
|
23
|
+
network_name = get_network_name(network, tag)
|
|
20
24
|
state_list = []
|
|
21
25
|
aux_data = {}
|
|
22
26
|
for idx, epoch in enumerate(epochs):
|
|
23
|
-
network_name = get_network_name(network, tag)
|
|
24
27
|
path = fs_ops.model_path(network_name, epoch=epoch)
|
|
25
28
|
logger.info(f"Loading model from {path}...")
|
|
26
29
|
|
|
@@ -51,12 +54,18 @@ def avg_models(network: str, tag: Optional[str], reparameterized: bool, epochs:
|
|
|
51
54
|
logger.info("Calculating averages...")
|
|
52
55
|
avg_state = {}
|
|
53
56
|
for state_name in state_list[0]:
|
|
54
|
-
|
|
57
|
+
t0 = state_list[0][state_name]
|
|
58
|
+
if torch.is_floating_point(t0) is True:
|
|
59
|
+
params = torch.empty((len(state_list),) + t0.size(), dtype=accum_dtype)
|
|
55
60
|
|
|
56
|
-
|
|
57
|
-
|
|
61
|
+
for idx, state in enumerate(state_list):
|
|
62
|
+
params[idx] = state[state_name].to(accum_dtype)
|
|
58
63
|
|
|
59
|
-
|
|
64
|
+
avg_state[state_name] = params.mean(dim=0).to(dtype=t0.dtype)
|
|
65
|
+
else:
|
|
66
|
+
# For int/bool buffers (e.g. num_batches_tracked / relative_position_index), averaging is not meaningful
|
|
67
|
+
logger.info(f"Not averaging non-floating state entry: {state_name} (dtype={t0.dtype})")
|
|
68
|
+
avg_state[state_name] = t0
|
|
60
69
|
|
|
61
70
|
net.load_state_dict(avg_state)
|
|
62
71
|
|
|
@@ -86,7 +95,7 @@ def set_parser(subparsers: Any) -> None:
|
|
|
86
95
|
epilog=(
|
|
87
96
|
"Usage examples:\n"
|
|
88
97
|
"python -m birder.tools avg-model --network efficientnet_v2_m --epochs 290 295 300\n"
|
|
89
|
-
"python -m birder.tools avg-model --network shufflenet_v2_2_0 --epochs 95 100 100\n"
|
|
98
|
+
"python -m birder.tools avg-model --network shufflenet_v2_2_0 --epochs 95 100 100 --accum-dtype float64\n"
|
|
90
99
|
),
|
|
91
100
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
92
101
|
)
|
|
@@ -98,9 +107,16 @@ def set_parser(subparsers: Any) -> None:
|
|
|
98
107
|
subparser.add_argument(
|
|
99
108
|
"-r", "--reparameterized", default=False, action="store_true", help="load reparameterized model"
|
|
100
109
|
)
|
|
110
|
+
subparser.add_argument(
|
|
111
|
+
"--accum-dtype",
|
|
112
|
+
choices=["float32", "float64"],
|
|
113
|
+
default="float32",
|
|
114
|
+
help="dtype used for averaging floating tensors",
|
|
115
|
+
)
|
|
101
116
|
subparser.add_argument("--force", action="store_true", help="override existing model")
|
|
102
117
|
subparser.set_defaults(func=main)
|
|
103
118
|
|
|
104
119
|
|
|
105
120
|
def main(args: argparse.Namespace) -> None:
|
|
106
|
-
|
|
121
|
+
accum_dtype: torch.dtype = getattr(torch, args.accum_dtype)
|
|
122
|
+
avg_models(args.network, args.tag, args.reparameterized, args.epochs, accum_dtype, args.force)
|
birder/tools/introspection.py
CHANGED
|
@@ -10,6 +10,7 @@ from birder.common import fs_ops
|
|
|
10
10
|
from birder.common import lib
|
|
11
11
|
from birder.data.transforms.classification import inference_preset
|
|
12
12
|
from birder.introspection import AttentionRollout
|
|
13
|
+
from birder.introspection import FeaturePCA
|
|
13
14
|
from birder.introspection import GradCAM
|
|
14
15
|
from birder.introspection import GuidedBackprop
|
|
15
16
|
from birder.introspection import TransformerAttribution
|
|
@@ -23,10 +24,7 @@ def _nhwc_reshape_transform(tensor: torch.Tensor) -> torch.Tensor:
|
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
def _show_attn_rollout(
|
|
26
|
-
args: argparse.Namespace,
|
|
27
|
-
net: BaseNet,
|
|
28
|
-
transform: Callable[..., torch.Tensor],
|
|
29
|
-
device: torch.device,
|
|
27
|
+
args: argparse.Namespace, net: BaseNet, transform: Callable[..., torch.Tensor], device: torch.device
|
|
30
28
|
) -> None:
|
|
31
29
|
ar = AttentionRollout(net, device, transform, args.attn_layer_name, args.discard_ratio, args.head_fusion)
|
|
32
30
|
result = ar(args.image_path)
|
|
@@ -92,6 +90,16 @@ def _show_grad_cam(
|
|
|
92
90
|
result.show()
|
|
93
91
|
|
|
94
92
|
|
|
93
|
+
def _show_feature_pca(
|
|
94
|
+
args: argparse.Namespace, net: BaseNet, transform: Callable[..., torch.Tensor], device: torch.device
|
|
95
|
+
) -> None:
|
|
96
|
+
feature_pca = FeaturePCA(
|
|
97
|
+
net, device, transform, args.normalize_features, channels_last=args.channels_last, stage=args.stage
|
|
98
|
+
)
|
|
99
|
+
result = feature_pca(args.image_path)
|
|
100
|
+
result.show()
|
|
101
|
+
|
|
102
|
+
|
|
95
103
|
def set_parser(subparsers: Any) -> None:
|
|
96
104
|
subparser = subparsers.add_parser(
|
|
97
105
|
"introspection",
|
|
@@ -102,6 +110,8 @@ def set_parser(subparsers: Any) -> None:
|
|
|
102
110
|
"Usage examples:\n"
|
|
103
111
|
"python -m birder.tools introspection --network efficientnet_v2_m -e 200 --method gradcam "
|
|
104
112
|
"'data/training/European goldfinch/000300.jpeg'\n"
|
|
113
|
+
"python -m birder.tools introspection -n convnext_v2_tiny -t vicreg --method feature-pca "
|
|
114
|
+
"--normalize-features --stage stage2 data/validation/Mallard/000015.jpeg\n"
|
|
105
115
|
"python -m birder.tools introspection -n resnest_50 --epoch 300 --method gradcam "
|
|
106
116
|
"data/index5.jpeg --target 'Grey heron'\n"
|
|
107
117
|
"python -m birder.tools introspection -n efficientnet_v2_s --method guided-backprop "
|
|
@@ -126,7 +136,7 @@ def set_parser(subparsers: Any) -> None:
|
|
|
126
136
|
subparser.add_argument(
|
|
127
137
|
"--method",
|
|
128
138
|
type=str,
|
|
129
|
-
choices=["
|
|
139
|
+
choices=["attn-rollout", "feature-pca", "gradcam", "guided-backprop", "transformer-attribution"],
|
|
130
140
|
help="introspection method",
|
|
131
141
|
)
|
|
132
142
|
subparser.add_argument(
|
|
@@ -142,7 +152,21 @@ def set_parser(subparsers: Any) -> None:
|
|
|
142
152
|
"--layer-num", type=int, default=-1, help="target layer, index for target block (gradcam only)"
|
|
143
153
|
)
|
|
144
154
|
subparser.add_argument(
|
|
145
|
-
"--channels-last",
|
|
155
|
+
"--channels-last",
|
|
156
|
+
default=False,
|
|
157
|
+
action="store_true",
|
|
158
|
+
help="channels last model, like swin (gradcam and feature-pca)",
|
|
159
|
+
)
|
|
160
|
+
subparser.add_argument(
|
|
161
|
+
"--normalize-features",
|
|
162
|
+
default=False,
|
|
163
|
+
action="store_true",
|
|
164
|
+
help="normalize feature vectors before PCA (feature-pca only)",
|
|
165
|
+
)
|
|
166
|
+
subparser.add_argument(
|
|
167
|
+
"--stage",
|
|
168
|
+
type=str,
|
|
169
|
+
help="stage to visualize, e.g., 'stage1', 'neck', etc. (feature-pca only, defaults to last stage)",
|
|
146
170
|
)
|
|
147
171
|
subparser.add_argument(
|
|
148
172
|
"--attn-layer-name",
|
|
@@ -193,11 +217,13 @@ def main(args: argparse.Namespace) -> None:
|
|
|
193
217
|
|
|
194
218
|
transform = inference_preset(args.size, model_info.rgb_stats, 1.0)
|
|
195
219
|
|
|
196
|
-
if args.method == "gradcam":
|
|
220
|
+
if args.method == "attn-rollout":
|
|
221
|
+
_show_attn_rollout(args, net, transform, device)
|
|
222
|
+
elif args.method == "feature-pca":
|
|
223
|
+
_show_feature_pca(args, net, transform, device)
|
|
224
|
+
elif args.method == "gradcam":
|
|
197
225
|
_show_grad_cam(args, net, model_info.class_to_idx, transform, device)
|
|
198
226
|
elif args.method == "guided-backprop":
|
|
199
227
|
_show_guided_backprop(args, net, model_info.class_to_idx, transform, device)
|
|
200
|
-
elif args.method == "attn-rollout":
|
|
201
|
-
_show_attn_rollout(args, net, transform, device)
|
|
202
228
|
elif args.method == "transformer-attribution":
|
|
203
229
|
_show_transformer_attribution(args, net, model_info.class_to_idx, transform, device)
|
birder/tools/show_iterator.py
CHANGED
|
@@ -140,10 +140,16 @@ def show_iterator(args: argparse.Namespace) -> None:
|
|
|
140
140
|
mask_size = (args.size[0] // args.patch_size, args.size[1] // args.patch_size)
|
|
141
141
|
mask_generator: Optional[masking.Masking]
|
|
142
142
|
if args.masking == "uniform":
|
|
143
|
-
mask_generator = masking.UniformMasking(mask_size, args.mask_ratio)
|
|
143
|
+
mask_generator = masking.UniformMasking(mask_size, args.mask_ratio, min_mask_size=args.min_mask_size)
|
|
144
144
|
elif args.masking == "block":
|
|
145
145
|
max_patches = int(args.mask_ratio * mask_size[0] * mask_size[1])
|
|
146
146
|
mask_generator = masking.BlockMasking(mask_size, 4, max_patches, 0.33, 3.33)
|
|
147
|
+
elif args.masking == "roll-block":
|
|
148
|
+
num_masking_patches = int(args.mask_ratio * mask_size[0] * mask_size[1])
|
|
149
|
+
mask_generator = masking.RollBlockMasking(mask_size, num_masking_patches=num_masking_patches)
|
|
150
|
+
elif args.masking == "inverse-roll":
|
|
151
|
+
num_masking_patches = int(args.mask_ratio * mask_size[0] * mask_size[1])
|
|
152
|
+
mask_generator = masking.InverseRollBlockMasking(mask_size, num_masking_patches=num_masking_patches)
|
|
147
153
|
else:
|
|
148
154
|
mask_generator = None
|
|
149
155
|
|
|
@@ -187,7 +193,7 @@ def set_parser(subparsers: Any) -> None:
|
|
|
187
193
|
"python -m birder.tools show-iterator --mode training --size 224 --batch --wds "
|
|
188
194
|
"--wds-class-file ~/Datasets/imagenet-1k-wds/classes.txt --wds-size 50000 "
|
|
189
195
|
"--data-path ~/Datasets/imagenet-1k-wds/validation\n"
|
|
190
|
-
"python -m birder.tools show-iterator --mode training --batch --size 224 --aug-level
|
|
196
|
+
"python -m birder.tools show-iterator --mode training --batch --size 224 --aug-level 1 --masking uniform\n"
|
|
191
197
|
"python -m birder.tools show-iterator --mode training --size 224 --batch --wds "
|
|
192
198
|
"--data-path data/training_packed\n"
|
|
193
199
|
"python -m birder.tools show-iterator --mode training --batch --mixup-alpha 0.8 --cutmix "
|
|
@@ -206,8 +212,16 @@ def set_parser(subparsers: Any) -> None:
|
|
|
206
212
|
)
|
|
207
213
|
subparser.add_argument("--mixup-alpha", type=float, help="mixup alpha")
|
|
208
214
|
subparser.add_argument("--cutmix", default=False, action="store_true", help="enable cutmix")
|
|
209
|
-
subparser.add_argument(
|
|
215
|
+
subparser.add_argument(
|
|
216
|
+
"--masking",
|
|
217
|
+
type=str,
|
|
218
|
+
choices=["uniform", "block", "roll-block", "inverse-roll"],
|
|
219
|
+
help="masking strategy to apply",
|
|
220
|
+
)
|
|
210
221
|
subparser.add_argument("--mask-ratio", type=float, default=0.5, help="mask ratio")
|
|
222
|
+
subparser.add_argument(
|
|
223
|
+
"--min-mask-size", type=int, default=1, help="minimum mask unit size in patches (uniform only)"
|
|
224
|
+
)
|
|
211
225
|
subparser.add_argument("--patch-size", type=int, default=16, help="mask base patch size")
|
|
212
226
|
subparser.add_argument(
|
|
213
227
|
"--data-path", type=str, default=str(settings.TRAINING_DATA_PATH), help="image directory path"
|
birder/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "v0.3.1"
|
|
1
|
+
__version__ = "v0.3.3"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder
|
|
3
|
-
Version: 0.3.1
|
|
3
|
+
Version: 0.3.3
|
|
4
4
|
Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|