birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +13 -13
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +6 -6
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +4 -4
- birder/layers/attention_pool.py +2 -2
- birder/layers/layer_scale.py +1 -1
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +4 -10
- birder/net/_rope_vit_configs.py +435 -0
- birder/net/_vit_configs.py +466 -0
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +18 -17
- birder/net/cait.py +7 -7
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +27 -27
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +3 -11
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +6 -6
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +11 -11
- birder/net/deit.py +68 -29
- birder/net/deit3.py +69 -204
- birder/net/densenet.py +9 -8
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +31 -30
- birder/net/detection/detr.py +14 -11
- birder/net/detection/efficientdet.py +10 -29
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +5 -4
- birder/net/edgevit.py +13 -14
- birder/net/efficientformer_v1.py +3 -2
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +7 -7
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +3 -3
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +50 -58
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +13 -13
- birder/net/hgnet_v1.py +6 -6
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +5 -15
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +11 -23
- birder/net/metaformer.py +5 -5
- birder/net/mim/crossmae.py +6 -6
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +4 -6
- birder/net/mim/simmim.py +3 -4
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +7 -34
- birder/net/mobilevit_v2.py +6 -54
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +30 -30
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +9 -9
- birder/net/pvt_v2.py +10 -16
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +5 -35
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +4 -1
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +62 -151
- birder/net/rope_flexivit.py +46 -33
- birder/net/rope_vit.py +44 -758
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +69 -21
- birder/net/smt.py +8 -8
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +4 -4
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +13 -3
- birder/net/ssl/franca.py +28 -4
- birder/net/ssl/i_jepa.py +5 -5
- birder/net/ssl/ibot.py +1 -1
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +13 -3
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +5 -8
- birder/net/tiny_vit.py +6 -19
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/van.py +2 -2
- birder/net/vgg.py +1 -10
- birder/net/vit.py +72 -987
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +23 -48
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +16 -13
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +12 -3
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +15 -15
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- birder-0.3.3.dist-info/RECORD +0 -299
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/scripts/train_i_jepa.py
CHANGED
|
@@ -74,7 +74,7 @@ class TrainCollator:
|
|
|
74
74
|
def __call__(self, batch: Any) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
|
|
75
75
|
B = len(batch)
|
|
76
76
|
collated_batch = torch.utils.data.default_collate(batch)
|
|
77
|
-
|
|
77
|
+
enc_masks, pred_masks = self.mask_generator(B)
|
|
78
78
|
|
|
79
79
|
return (collated_batch, enc_masks, pred_masks)
|
|
80
80
|
|
|
@@ -84,7 +84,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
84
84
|
#
|
|
85
85
|
# Initialize
|
|
86
86
|
#
|
|
87
|
-
|
|
87
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
88
88
|
|
|
89
89
|
if args.size is None:
|
|
90
90
|
args.size = registry.get_default_size(args.network)
|
|
@@ -119,9 +119,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
119
119
|
else:
|
|
120
120
|
model_config = {"drop_path_rate": 0.0}
|
|
121
121
|
|
|
122
|
-
backbone = registry.net_factory(args.network, sample_shape[1],
|
|
122
|
+
backbone = registry.net_factory(args.network, 0, sample_shape[1], config=model_config, size=args.size)
|
|
123
123
|
num_special_tokens = backbone.num_special_tokens
|
|
124
|
-
target_backbone = registry.net_factory(args.network, sample_shape[1],
|
|
124
|
+
target_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=model_config, size=args.size)
|
|
125
125
|
encoder = I_JEPA(backbone)
|
|
126
126
|
target_encoder = I_JEPA(target_backbone)
|
|
127
127
|
target_encoder.load_state_dict(encoder.state_dict())
|
|
@@ -148,7 +148,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
148
148
|
|
|
149
149
|
if args.resume_epoch is not None:
|
|
150
150
|
begin_epoch = args.resume_epoch + 1
|
|
151
|
-
|
|
151
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
152
152
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
153
153
|
)
|
|
154
154
|
encoder = net["encoder"]
|
|
@@ -198,11 +198,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
198
198
|
elif args.wds is True:
|
|
199
199
|
wds_path: str | list[str]
|
|
200
200
|
if args.wds_info is not None:
|
|
201
|
-
|
|
201
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
202
202
|
if args.wds_size is not None:
|
|
203
203
|
dataset_size = args.wds_size
|
|
204
204
|
else:
|
|
205
|
-
|
|
205
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
206
206
|
|
|
207
207
|
training_dataset = make_wds_dataset(
|
|
208
208
|
wds_path,
|
|
@@ -228,7 +228,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
228
228
|
|
|
229
229
|
# Data loaders and samplers
|
|
230
230
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
231
|
-
|
|
231
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
232
232
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
233
233
|
)
|
|
234
234
|
|
|
@@ -320,7 +320,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
320
320
|
wd_schedule = None
|
|
321
321
|
|
|
322
322
|
# Gradient scaler and AMP related tasks
|
|
323
|
-
|
|
323
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
324
324
|
|
|
325
325
|
# Load states
|
|
326
326
|
if args.load_states is True:
|
|
@@ -440,6 +440,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
440
440
|
tic = time.time()
|
|
441
441
|
net.train()
|
|
442
442
|
|
|
443
|
+
# Clear metrics
|
|
444
|
+
running_loss.clear()
|
|
445
|
+
|
|
443
446
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
444
447
|
train_sampler.set_epoch(epoch)
|
|
445
448
|
|
birder/scripts/train_ibot.py
CHANGED
|
@@ -107,7 +107,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
107
107
|
#
|
|
108
108
|
# Initialize
|
|
109
109
|
#
|
|
110
|
-
|
|
110
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
111
111
|
|
|
112
112
|
if args.size is None:
|
|
113
113
|
args.size = registry.get_default_size(args.network)
|
|
@@ -136,7 +136,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
136
136
|
|
|
137
137
|
network_name = get_mim_network_name("ibot", encoder=args.network, tag=args.tag)
|
|
138
138
|
|
|
139
|
-
student_backbone = registry.net_factory(args.network, sample_shape[1],
|
|
139
|
+
student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
140
140
|
if args.model_config is not None:
|
|
141
141
|
teacher_model_config = args.model_config.copy()
|
|
142
142
|
teacher_model_config.update({"drop_path_rate": 0.0})
|
|
@@ -144,7 +144,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
144
144
|
teacher_model_config = {"drop_path_rate": 0.0}
|
|
145
145
|
|
|
146
146
|
teacher_backbone = registry.net_factory(
|
|
147
|
-
args.network, sample_shape[1],
|
|
147
|
+
args.network, 0, sample_shape[1], config=teacher_model_config, size=args.size
|
|
148
148
|
)
|
|
149
149
|
student_backbone.set_dynamic_size()
|
|
150
150
|
student = iBOT(
|
|
@@ -204,7 +204,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
204
204
|
|
|
205
205
|
if args.resume_epoch is not None:
|
|
206
206
|
begin_epoch = args.resume_epoch + 1
|
|
207
|
-
|
|
207
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
208
208
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
209
209
|
)
|
|
210
210
|
student = net["student"]
|
|
@@ -266,11 +266,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
266
266
|
elif args.wds is True:
|
|
267
267
|
wds_path: str | list[str]
|
|
268
268
|
if args.wds_info is not None:
|
|
269
|
-
|
|
269
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
270
270
|
if args.wds_size is not None:
|
|
271
271
|
dataset_size = args.wds_size
|
|
272
272
|
else:
|
|
273
|
-
|
|
273
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
274
274
|
|
|
275
275
|
training_dataset = make_wds_dataset(
|
|
276
276
|
wds_path,
|
|
@@ -296,7 +296,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
296
296
|
|
|
297
297
|
# Data loaders and samplers
|
|
298
298
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
299
|
-
|
|
299
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
300
300
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
301
301
|
)
|
|
302
302
|
|
|
@@ -387,7 +387,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
387
387
|
wd_schedule = None
|
|
388
388
|
|
|
389
389
|
# Gradient scaler and AMP related tasks
|
|
390
|
-
|
|
390
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
391
391
|
|
|
392
392
|
# Load states
|
|
393
393
|
if args.load_states is True:
|
|
@@ -507,6 +507,10 @@ def train(args: argparse.Namespace) -> None:
|
|
|
507
507
|
tic = time.time()
|
|
508
508
|
net.train()
|
|
509
509
|
|
|
510
|
+
# Clear metrics
|
|
511
|
+
running_loss.clear()
|
|
512
|
+
train_proto_agreement.clear()
|
|
513
|
+
|
|
510
514
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
511
515
|
train_sampler.set_epoch(epoch)
|
|
512
516
|
|
|
@@ -553,12 +557,12 @@ def train(args: argparse.Namespace) -> None:
|
|
|
553
557
|
with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
|
|
554
558
|
# Global views
|
|
555
559
|
with torch.no_grad():
|
|
556
|
-
|
|
560
|
+
teacher_embedding, teacher_features = teacher(torch.concat(images[:2], dim=0), None)
|
|
557
561
|
|
|
558
|
-
|
|
562
|
+
student_embedding, student_features = student(torch.concat(images[:2], dim=0), masks)
|
|
559
563
|
|
|
560
564
|
# Local views
|
|
561
|
-
|
|
565
|
+
student_local_embedding, _ = student(torch.concat(images[2:], dim=0), None, return_keys="embedding")
|
|
562
566
|
|
|
563
567
|
loss = ibot_loss(
|
|
564
568
|
student_embedding,
|
birder/scripts/train_kd.py
CHANGED
|
@@ -76,13 +76,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
76
76
|
#
|
|
77
77
|
# Initialize
|
|
78
78
|
#
|
|
79
|
-
|
|
79
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
80
80
|
|
|
81
81
|
if args.type != "soft":
|
|
82
82
|
args.temperature = 1.0
|
|
83
83
|
|
|
84
84
|
# Using the teacher rgb values for the student
|
|
85
|
-
|
|
85
|
+
teacher, (class_to_idx, signature, rgb_stats, *_) = fs_ops.load_model(
|
|
86
86
|
device,
|
|
87
87
|
args.teacher,
|
|
88
88
|
config=args.teacher_model_config,
|
|
@@ -113,15 +113,15 @@ def train(args: argparse.Namespace) -> None:
|
|
|
113
113
|
training_wds_path: str | list[str]
|
|
114
114
|
val_wds_path: str | list[str]
|
|
115
115
|
if args.wds_info is not None:
|
|
116
|
-
|
|
117
|
-
|
|
116
|
+
training_wds_path, training_size = wds_args_from_info(args.wds_info, args.wds_training_split)
|
|
117
|
+
val_wds_path, val_size = wds_args_from_info(args.wds_info, args.wds_val_split)
|
|
118
118
|
if args.wds_train_size is not None:
|
|
119
119
|
training_size = args.wds_train_size
|
|
120
120
|
if args.wds_val_size is not None:
|
|
121
121
|
val_size = args.wds_val_size
|
|
122
122
|
else:
|
|
123
|
-
|
|
124
|
-
|
|
123
|
+
training_wds_path, training_size = prepare_wds_args(args.data_path, args.wds_train_size, device)
|
|
124
|
+
val_wds_path, val_size = prepare_wds_args(args.val_path, args.wds_val_size, device)
|
|
125
125
|
|
|
126
126
|
training_dataset = make_wds_dataset(
|
|
127
127
|
training_wds_path,
|
|
@@ -187,7 +187,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
187
187
|
|
|
188
188
|
# Data loaders and samplers
|
|
189
189
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
190
|
-
|
|
190
|
+
train_sampler, validation_sampler = training_utils.get_samplers(
|
|
191
191
|
args, training_dataset, validation_dataset, infinite=virtual_epoch_mode
|
|
192
192
|
)
|
|
193
193
|
|
|
@@ -269,7 +269,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
269
269
|
|
|
270
270
|
if args.resume_epoch is not None:
|
|
271
271
|
begin_epoch = args.resume_epoch + 1
|
|
272
|
-
|
|
272
|
+
student, class_to_idx_saved, training_states = fs_ops.load_checkpoint(
|
|
273
273
|
device,
|
|
274
274
|
args.student,
|
|
275
275
|
config=args.student_model_config,
|
|
@@ -283,8 +283,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
283
283
|
else:
|
|
284
284
|
student = registry.net_factory(
|
|
285
285
|
args.student,
|
|
286
|
-
sample_shape[1],
|
|
287
286
|
num_outputs,
|
|
287
|
+
sample_shape[1],
|
|
288
288
|
config=args.student_model_config,
|
|
289
289
|
size=args.size,
|
|
290
290
|
)
|
|
@@ -383,7 +383,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
383
383
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
384
384
|
|
|
385
385
|
# Gradient scaler and AMP related tasks
|
|
386
|
-
|
|
386
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
387
387
|
|
|
388
388
|
# Load states
|
|
389
389
|
if args.load_states is True:
|
|
@@ -567,10 +567,16 @@ def train(args: argparse.Namespace) -> None:
|
|
|
567
567
|
if virtual_epoch_mode is True:
|
|
568
568
|
train_iter = iter(training_loader)
|
|
569
569
|
|
|
570
|
+
top_k = args.top_k
|
|
570
571
|
running_loss = training_utils.SmoothedValue(window_size=64)
|
|
571
572
|
running_val_loss = training_utils.SmoothedValue()
|
|
572
573
|
train_accuracy = training_utils.SmoothedValue(window_size=64)
|
|
573
574
|
val_accuracy = training_utils.SmoothedValue()
|
|
575
|
+
train_topk: Optional[training_utils.SmoothedValue] = None
|
|
576
|
+
val_topk: Optional[training_utils.SmoothedValue] = None
|
|
577
|
+
if top_k is not None:
|
|
578
|
+
train_topk = training_utils.SmoothedValue(window_size=64)
|
|
579
|
+
val_topk = training_utils.SmoothedValue()
|
|
574
580
|
|
|
575
581
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
576
582
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
@@ -579,6 +585,16 @@ def train(args: argparse.Namespace) -> None:
|
|
|
579
585
|
if embedding_projection is not None:
|
|
580
586
|
embedding_projection.train()
|
|
581
587
|
|
|
588
|
+
# Clear metrics
|
|
589
|
+
running_loss.clear()
|
|
590
|
+
running_val_loss.clear()
|
|
591
|
+
train_accuracy.clear()
|
|
592
|
+
val_accuracy.clear()
|
|
593
|
+
if train_topk is not None:
|
|
594
|
+
train_topk.clear()
|
|
595
|
+
if val_topk is not None:
|
|
596
|
+
val_topk.clear()
|
|
597
|
+
|
|
582
598
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
583
599
|
train_sampler.set_epoch(epoch)
|
|
584
600
|
|
|
@@ -616,7 +632,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
616
632
|
teacher_embedding = teacher.embedding(inputs)
|
|
617
633
|
teacher_embedding = F.normalize(teacher_embedding, dim=-1)
|
|
618
634
|
|
|
619
|
-
|
|
635
|
+
outputs, student_embedding = train_student(inputs)
|
|
620
636
|
student_embedding = embedding_projection(student_embedding) # type: ignore[misc]
|
|
621
637
|
student_embedding = F.normalize(student_embedding, dim=-1)
|
|
622
638
|
dist_loss = distillation_criterion(student_embedding, teacher_embedding)
|
|
@@ -637,7 +653,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
637
653
|
outputs = train_student(inputs)
|
|
638
654
|
dist_loss = distillation_criterion(outputs, teacher_targets)
|
|
639
655
|
elif distillation_type == "deit":
|
|
640
|
-
|
|
656
|
+
outputs, dist_output = torch.unbind(train_student(inputs), dim=1)
|
|
641
657
|
dist_loss = distillation_criterion(dist_output, teacher_targets)
|
|
642
658
|
else:
|
|
643
659
|
raise RuntimeError
|
|
@@ -693,6 +709,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
693
709
|
targets = targets.argmax(dim=1)
|
|
694
710
|
|
|
695
711
|
train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
|
|
712
|
+
if train_topk is not None:
|
|
713
|
+
topk_val = training_utils.topk_accuracy(targets, outputs.detach(), topk=(top_k,))[0]
|
|
714
|
+
train_topk.update(topk_val)
|
|
696
715
|
|
|
697
716
|
# Write statistics
|
|
698
717
|
if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
|
|
@@ -711,6 +730,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
711
730
|
|
|
712
731
|
running_loss.synchronize_between_processes(device)
|
|
713
732
|
train_accuracy.synchronize_between_processes(device)
|
|
733
|
+
if train_topk is not None:
|
|
734
|
+
train_topk.synchronize_between_processes(device)
|
|
735
|
+
|
|
714
736
|
with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
|
|
715
737
|
log.info(
|
|
716
738
|
f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
|
|
@@ -725,8 +747,17 @@ def train(args: argparse.Namespace) -> None:
|
|
|
725
747
|
f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
|
|
726
748
|
f"Accuracy: {train_accuracy.avg:.4f}"
|
|
727
749
|
)
|
|
750
|
+
if train_topk is not None:
|
|
751
|
+
log.info(
|
|
752
|
+
f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
|
|
753
|
+
f"Accuracy@{top_k}: {train_topk.avg:.4f}"
|
|
754
|
+
)
|
|
728
755
|
|
|
729
756
|
if training_utils.is_local_primary(args) is True:
|
|
757
|
+
performance = {"training_accuracy": train_accuracy.avg}
|
|
758
|
+
if train_topk is not None:
|
|
759
|
+
performance[f"training_accuracy@{top_k}"] = train_topk.avg
|
|
760
|
+
|
|
730
761
|
summary_writer.add_scalars(
|
|
731
762
|
"loss",
|
|
732
763
|
{"training": running_loss.avg},
|
|
@@ -734,7 +765,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
734
765
|
)
|
|
735
766
|
summary_writer.add_scalars(
|
|
736
767
|
"performance",
|
|
737
|
-
|
|
768
|
+
performance,
|
|
738
769
|
((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
|
|
739
770
|
)
|
|
740
771
|
|
|
@@ -746,6 +777,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
746
777
|
# Epoch training metrics
|
|
747
778
|
logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_loss: {running_loss.global_avg:.4f}")
|
|
748
779
|
logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_accuracy: {train_accuracy.global_avg:.4f}")
|
|
780
|
+
if train_topk is not None:
|
|
781
|
+
logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_accuracy@{top_k}: {train_topk.global_avg:.4f}")
|
|
749
782
|
|
|
750
783
|
# Validation
|
|
751
784
|
eval_model.eval()
|
|
@@ -772,6 +805,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
772
805
|
# Statistics
|
|
773
806
|
running_val_loss.update(val_loss.detach())
|
|
774
807
|
val_accuracy.update(training_utils.accuracy(targets, outputs), n=outputs.size(0))
|
|
808
|
+
if val_topk is not None:
|
|
809
|
+
topk_val = training_utils.topk_accuracy(targets, outputs, topk=(top_k,))[0]
|
|
810
|
+
val_topk.update(topk_val, n=outputs.size(0))
|
|
775
811
|
|
|
776
812
|
# Update progress bar
|
|
777
813
|
progress.update(n=batch_size * args.world_size)
|
|
@@ -789,19 +825,30 @@ def train(args: argparse.Namespace) -> None:
|
|
|
789
825
|
|
|
790
826
|
running_val_loss.synchronize_between_processes(device)
|
|
791
827
|
val_accuracy.synchronize_between_processes(device)
|
|
828
|
+
if val_topk is not None:
|
|
829
|
+
val_topk.synchronize_between_processes(device)
|
|
830
|
+
|
|
792
831
|
epoch_val_loss = running_val_loss.global_avg
|
|
793
832
|
epoch_val_accuracy = val_accuracy.global_avg
|
|
833
|
+
if val_topk is not None:
|
|
834
|
+
epoch_val_topk = val_topk.global_avg
|
|
835
|
+
else:
|
|
836
|
+
epoch_val_topk = None
|
|
794
837
|
|
|
795
838
|
# Write statistics
|
|
796
839
|
if training_utils.is_local_primary(args) is True:
|
|
797
840
|
summary_writer.add_scalars("loss", {"validation": epoch_val_loss}, epoch * epoch_samples)
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
841
|
+
performance = {"validation_accuracy": epoch_val_accuracy}
|
|
842
|
+
if epoch_val_topk is not None:
|
|
843
|
+
performance[f"validation_accuracy@{top_k}"] = epoch_val_topk
|
|
844
|
+
|
|
845
|
+
summary_writer.add_scalars("performance", performance, epoch * epoch_samples)
|
|
801
846
|
|
|
802
847
|
# Epoch validation metrics
|
|
803
848
|
logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_loss (target only): {epoch_val_loss:.4f}")
|
|
804
849
|
logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_accuracy: {epoch_val_accuracy:.4f}")
|
|
850
|
+
if epoch_val_topk is not None:
|
|
851
|
+
logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_accuracy@{top_k}: {epoch_val_topk:.4f}")
|
|
805
852
|
|
|
806
853
|
# Learning rate scheduler update
|
|
807
854
|
if step_update is False:
|
|
@@ -989,7 +1036,7 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
989
1036
|
training_cli.add_compile_args(parser, teacher=True)
|
|
990
1037
|
training_cli.add_checkpoint_args(parser, default_save_frequency=5)
|
|
991
1038
|
training_cli.add_distributed_args(parser)
|
|
992
|
-
training_cli.add_logging_and_debug_args(parser)
|
|
1039
|
+
training_cli.add_logging_and_debug_args(parser, classification=True)
|
|
993
1040
|
training_cli.add_training_data_args(parser)
|
|
994
1041
|
|
|
995
1042
|
return parser
|
birder/scripts/train_mim.py
CHANGED
|
@@ -49,7 +49,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
49
49
|
#
|
|
50
50
|
# Initialize
|
|
51
51
|
#
|
|
52
|
-
|
|
52
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
53
53
|
|
|
54
54
|
if args.size is None:
|
|
55
55
|
# Prefer mim size over encoder default size
|
|
@@ -73,11 +73,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
73
73
|
elif args.wds is True:
|
|
74
74
|
wds_path: str | list[str]
|
|
75
75
|
if args.wds_info is not None:
|
|
76
|
-
|
|
76
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
77
77
|
if args.wds_size is not None:
|
|
78
78
|
dataset_size = args.wds_size
|
|
79
79
|
else:
|
|
80
|
-
|
|
80
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
81
81
|
|
|
82
82
|
training_dataset = make_wds_dataset(
|
|
83
83
|
wds_path,
|
|
@@ -107,7 +107,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
107
107
|
|
|
108
108
|
# Data loaders and samplers
|
|
109
109
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
110
|
-
|
|
110
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
111
111
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
112
112
|
)
|
|
113
113
|
|
|
@@ -172,7 +172,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
172
172
|
|
|
173
173
|
if args.resume_epoch is not None:
|
|
174
174
|
begin_epoch = args.resume_epoch + 1
|
|
175
|
-
|
|
175
|
+
net, training_states = fs_ops.load_mim_checkpoint(
|
|
176
176
|
device,
|
|
177
177
|
args.network,
|
|
178
178
|
config=args.model_config,
|
|
@@ -187,7 +187,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
187
187
|
|
|
188
188
|
elif args.pretrained is True:
|
|
189
189
|
fs_ops.download_model_by_weights(network_name, progress_bar=training_utils.is_local_primary(args))
|
|
190
|
-
|
|
190
|
+
net, training_states = fs_ops.load_mim_checkpoint(
|
|
191
191
|
device,
|
|
192
192
|
args.network,
|
|
193
193
|
config=args.model_config,
|
|
@@ -202,7 +202,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
202
202
|
|
|
203
203
|
else:
|
|
204
204
|
encoder = registry.net_factory(
|
|
205
|
-
args.encoder, sample_shape[1],
|
|
205
|
+
args.encoder, 0, sample_shape[1], config=args.encoder_model_config, size=args.size
|
|
206
206
|
)
|
|
207
207
|
net = registry.mim_net_factory(
|
|
208
208
|
args.network,
|
|
@@ -263,7 +263,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
263
263
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
264
264
|
|
|
265
265
|
# Gradient scaler and AMP related tasks
|
|
266
|
-
|
|
266
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
267
267
|
|
|
268
268
|
# Load states
|
|
269
269
|
if args.load_states is True:
|
|
@@ -375,6 +375,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
375
375
|
tic = time.time()
|
|
376
376
|
net.train()
|
|
377
377
|
|
|
378
|
+
# Clear metrics
|
|
379
|
+
running_loss.clear()
|
|
380
|
+
|
|
378
381
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
379
382
|
train_sampler.set_epoch(epoch)
|
|
380
383
|
|
birder/scripts/train_mmcr.py
CHANGED
|
@@ -74,7 +74,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
74
74
|
#
|
|
75
75
|
# Initialize
|
|
76
76
|
#
|
|
77
|
-
|
|
77
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
78
78
|
|
|
79
79
|
if args.size is None:
|
|
80
80
|
args.size = registry.get_default_size(args.network)
|
|
@@ -97,11 +97,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
97
97
|
elif args.wds is True:
|
|
98
98
|
wds_path: str | list[str]
|
|
99
99
|
if args.wds_info is not None:
|
|
100
|
-
|
|
100
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
101
101
|
if args.wds_size is not None:
|
|
102
102
|
dataset_size = args.wds_size
|
|
103
103
|
else:
|
|
104
|
-
|
|
104
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
105
105
|
|
|
106
106
|
training_dataset = make_wds_dataset(
|
|
107
107
|
wds_path,
|
|
@@ -131,7 +131,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
131
131
|
|
|
132
132
|
# Data loaders and samplers
|
|
133
133
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
134
|
-
|
|
134
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
135
135
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
136
136
|
)
|
|
137
137
|
|
|
@@ -194,12 +194,12 @@ def train(args: argparse.Namespace) -> None:
|
|
|
194
194
|
|
|
195
195
|
network_name = get_mim_network_name("mmcr", encoder=args.network, tag=args.tag)
|
|
196
196
|
|
|
197
|
-
backbone = registry.net_factory(args.network, sample_shape[1],
|
|
197
|
+
backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
198
198
|
net = MMCR(backbone, config={"projector_dims": args.projector_dims})
|
|
199
199
|
|
|
200
200
|
if args.resume_epoch is not None:
|
|
201
201
|
begin_epoch = args.resume_epoch + 1
|
|
202
|
-
|
|
202
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
203
203
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
204
204
|
)
|
|
205
205
|
|
|
@@ -265,7 +265,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
265
265
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
266
266
|
|
|
267
267
|
# Gradient scaler and AMP related tasks
|
|
268
|
-
|
|
268
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
269
269
|
|
|
270
270
|
# Load states
|
|
271
271
|
if args.load_states is True:
|
|
@@ -377,6 +377,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
377
377
|
tic = time.time()
|
|
378
378
|
net.train()
|
|
379
379
|
|
|
380
|
+
# Clear metrics
|
|
381
|
+
running_loss.clear()
|
|
382
|
+
|
|
380
383
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
381
384
|
train_sampler.set_epoch(epoch)
|
|
382
385
|
|
|
@@ -407,7 +410,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
407
410
|
|
|
408
411
|
# Forward, backward and optimize
|
|
409
412
|
with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
|
|
410
|
-
|
|
413
|
+
z, z_m = net(images)
|
|
411
414
|
loss = mmcr_loss(z, z_m)
|
|
412
415
|
|
|
413
416
|
if scaler is not None:
|
birder/scripts/train_rotnet.py
CHANGED
|
@@ -83,7 +83,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
83
83
|
#
|
|
84
84
|
# Initialize
|
|
85
85
|
#
|
|
86
|
-
|
|
86
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
87
87
|
|
|
88
88
|
if args.size is None:
|
|
89
89
|
args.size = registry.get_default_size(args.network)
|
|
@@ -111,11 +111,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
111
111
|
elif args.wds is True:
|
|
112
112
|
wds_path: str | list[str]
|
|
113
113
|
if args.wds_info is not None:
|
|
114
|
-
|
|
114
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
115
115
|
if args.wds_size is not None:
|
|
116
116
|
dataset_size = args.wds_size
|
|
117
117
|
else:
|
|
118
|
-
|
|
118
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
119
119
|
|
|
120
120
|
training_dataset = make_wds_dataset(
|
|
121
121
|
wds_path,
|
|
@@ -145,7 +145,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
145
145
|
|
|
146
146
|
# Data loaders and samplers
|
|
147
147
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
148
|
-
|
|
148
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
149
149
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
150
150
|
)
|
|
151
151
|
|
|
@@ -207,12 +207,12 @@ def train(args: argparse.Namespace) -> None:
|
|
|
207
207
|
network_name = f"{network_name}-{args.tag}"
|
|
208
208
|
|
|
209
209
|
net = registry.net_factory(
|
|
210
|
-
args.network, sample_shape[1],
|
|
210
|
+
args.network, len(class_to_idx), sample_shape[1], config=args.model_config, size=args.size
|
|
211
211
|
)
|
|
212
212
|
|
|
213
213
|
if args.resume_epoch is not None:
|
|
214
214
|
begin_epoch = args.resume_epoch + 1
|
|
215
|
-
|
|
215
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
216
216
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
217
217
|
)
|
|
218
218
|
|
|
@@ -277,7 +277,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
277
277
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
278
278
|
|
|
279
279
|
# Gradient scaler and AMP related tasks
|
|
280
|
-
|
|
280
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
281
281
|
|
|
282
282
|
# Load states
|
|
283
283
|
if args.load_states is True:
|
|
@@ -389,6 +389,10 @@ def train(args: argparse.Namespace) -> None:
|
|
|
389
389
|
tic = time.time()
|
|
390
390
|
net.train()
|
|
391
391
|
|
|
392
|
+
# Clear metrics
|
|
393
|
+
running_loss.clear()
|
|
394
|
+
train_accuracy.clear()
|
|
395
|
+
|
|
392
396
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
393
397
|
train_sampler.set_epoch(epoch)
|
|
394
398
|
|