birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
|
@@ -75,7 +75,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
75
75
|
#
|
|
76
76
|
# Initialize
|
|
77
77
|
#
|
|
78
|
-
|
|
78
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
79
79
|
|
|
80
80
|
if args.size is None:
|
|
81
81
|
# Prefer mim size over encoder default size
|
|
@@ -105,7 +105,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
105
105
|
|
|
106
106
|
network_name = get_mim_network_name("data2vec2", encoder=args.network, tag=args.tag)
|
|
107
107
|
|
|
108
|
-
backbone = registry.net_factory(args.network, sample_shape[1],
|
|
108
|
+
backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
109
109
|
net = Data2Vec2(
|
|
110
110
|
backbone,
|
|
111
111
|
config={
|
|
@@ -121,7 +121,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
121
121
|
|
|
122
122
|
if args.resume_epoch is not None:
|
|
123
123
|
begin_epoch = args.resume_epoch + 1
|
|
124
|
-
|
|
124
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
125
125
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
126
126
|
)
|
|
127
127
|
|
|
@@ -169,11 +169,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
169
169
|
elif args.wds is True:
|
|
170
170
|
wds_path: str | list[str]
|
|
171
171
|
if args.wds_info is not None:
|
|
172
|
-
|
|
172
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
173
173
|
if args.wds_size is not None:
|
|
174
174
|
dataset_size = args.wds_size
|
|
175
175
|
else:
|
|
176
|
-
|
|
176
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
177
177
|
|
|
178
178
|
training_dataset = make_wds_dataset(
|
|
179
179
|
wds_path,
|
|
@@ -199,7 +199,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
199
199
|
|
|
200
200
|
# Data loaders and samplers
|
|
201
201
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
202
|
-
|
|
202
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
203
203
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
204
204
|
)
|
|
205
205
|
|
|
@@ -288,7 +288,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
288
288
|
)
|
|
289
289
|
|
|
290
290
|
# Gradient scaler and AMP related tasks
|
|
291
|
-
|
|
291
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
292
292
|
|
|
293
293
|
# Load states
|
|
294
294
|
if args.load_states is True:
|
|
@@ -400,6 +400,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
400
400
|
tic = time.time()
|
|
401
401
|
net.train()
|
|
402
402
|
|
|
403
|
+
# Clear metrics
|
|
404
|
+
running_loss.clear()
|
|
405
|
+
|
|
403
406
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
404
407
|
train_sampler.set_epoch(epoch)
|
|
405
408
|
|
|
@@ -27,7 +27,7 @@ from birder.common import training_cli
|
|
|
27
27
|
from birder.common import training_utils
|
|
28
28
|
from birder.conf import settings
|
|
29
29
|
from birder.data.collators.detection import BatchRandomResizeCollator
|
|
30
|
-
from birder.data.collators.detection import
|
|
30
|
+
from birder.data.collators.detection import DetectionCollator
|
|
31
31
|
from birder.data.datasets.coco import CocoMosaicTraining
|
|
32
32
|
from birder.data.datasets.coco import CocoTraining
|
|
33
33
|
from birder.data.transforms.classification import get_rgb_stats
|
|
@@ -63,7 +63,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
63
63
|
)
|
|
64
64
|
model_dynamic_size = transform_dynamic_size or args.batch_multiscale is True
|
|
65
65
|
|
|
66
|
-
|
|
66
|
+
device, device_id, disable_tqdm = training_utils.init_training(
|
|
67
67
|
args, logger, cudnn_dynamic_size=transform_dynamic_size
|
|
68
68
|
)
|
|
69
69
|
|
|
@@ -92,6 +92,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
92
92
|
args.multiscale,
|
|
93
93
|
args.max_size,
|
|
94
94
|
args.multiscale_min_size,
|
|
95
|
+
args.multiscale_step,
|
|
95
96
|
)
|
|
96
97
|
mosaic_dataset = None
|
|
97
98
|
if args.mosaic_prob > 0.0:
|
|
@@ -104,6 +105,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
104
105
|
args.multiscale,
|
|
105
106
|
args.max_size,
|
|
106
107
|
args.multiscale_min_size,
|
|
108
|
+
args.multiscale_step,
|
|
107
109
|
post_mosaic=True,
|
|
108
110
|
)
|
|
109
111
|
if args.dynamic_size is True or args.multiscale is True:
|
|
@@ -177,14 +179,22 @@ def train(args: argparse.Namespace) -> None:
|
|
|
177
179
|
|
|
178
180
|
# Data loaders and samplers
|
|
179
181
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
180
|
-
|
|
182
|
+
train_sampler, validation_sampler = training_utils.get_samplers(
|
|
181
183
|
args, training_dataset, validation_dataset, infinite=virtual_epoch_mode
|
|
182
184
|
)
|
|
183
185
|
|
|
184
186
|
if args.batch_multiscale is True:
|
|
185
|
-
train_collate_fn: Any = BatchRandomResizeCollator(
|
|
187
|
+
train_collate_fn: Any = BatchRandomResizeCollator(
|
|
188
|
+
0,
|
|
189
|
+
args.size,
|
|
190
|
+
size_divisible=args.multiscale_step,
|
|
191
|
+
multiscale_min_size=args.multiscale_min_size,
|
|
192
|
+
multiscale_step=args.multiscale_step,
|
|
193
|
+
)
|
|
186
194
|
else:
|
|
187
|
-
train_collate_fn =
|
|
195
|
+
train_collate_fn = DetectionCollator(0, size_divisible=args.multiscale_step)
|
|
196
|
+
|
|
197
|
+
validation_collate_fn = DetectionCollator(0, size_divisible=args.multiscale_step)
|
|
188
198
|
|
|
189
199
|
training_loader = DataLoader(
|
|
190
200
|
training_dataset,
|
|
@@ -202,7 +212,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
202
212
|
sampler=validation_sampler,
|
|
203
213
|
num_workers=args.num_workers,
|
|
204
214
|
prefetch_factor=args.prefetch_factor,
|
|
205
|
-
collate_fn=
|
|
215
|
+
collate_fn=validation_collate_fn,
|
|
206
216
|
pin_memory=True,
|
|
207
217
|
drop_last=args.drop_last,
|
|
208
218
|
)
|
|
@@ -243,7 +253,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
243
253
|
|
|
244
254
|
if args.resume_epoch is not None:
|
|
245
255
|
begin_epoch = args.resume_epoch + 1
|
|
246
|
-
|
|
256
|
+
net, class_to_idx_saved, training_states = fs_ops.load_detection_checkpoint(
|
|
247
257
|
device,
|
|
248
258
|
args.network,
|
|
249
259
|
config=args.model_config,
|
|
@@ -262,7 +272,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
262
272
|
|
|
263
273
|
elif args.pretrained is True:
|
|
264
274
|
fs_ops.download_model_by_weights(network_name, progress_bar=training_utils.is_local_primary(args))
|
|
265
|
-
|
|
275
|
+
net, class_to_idx_saved, training_states = fs_ops.load_detection_checkpoint(
|
|
266
276
|
device,
|
|
267
277
|
args.network,
|
|
268
278
|
config=args.model_config,
|
|
@@ -282,7 +292,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
282
292
|
else:
|
|
283
293
|
if args.backbone_epoch is not None:
|
|
284
294
|
backbone: DetectorBackbone
|
|
285
|
-
|
|
295
|
+
backbone, class_to_idx_saved, _ = fs_ops.load_checkpoint(
|
|
286
296
|
device,
|
|
287
297
|
args.backbone,
|
|
288
298
|
config=args.backbone_model_config,
|
|
@@ -297,7 +307,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
297
307
|
lib.get_network_name(args.backbone, tag=args.backbone_tag),
|
|
298
308
|
progress_bar=training_utils.is_local_primary(args),
|
|
299
309
|
)
|
|
300
|
-
|
|
310
|
+
backbone, class_to_idx_saved, _ = fs_ops.load_checkpoint(
|
|
301
311
|
device,
|
|
302
312
|
args.backbone,
|
|
303
313
|
config=args.backbone_model_config,
|
|
@@ -309,7 +319,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
309
319
|
|
|
310
320
|
else:
|
|
311
321
|
backbone = registry.net_factory(
|
|
312
|
-
args.backbone, sample_shape[1],
|
|
322
|
+
args.backbone, num_outputs, sample_shape[1], config=args.backbone_model_config, size=args.size
|
|
313
323
|
)
|
|
314
324
|
|
|
315
325
|
net = registry.detection_net_factory(
|
|
@@ -386,7 +396,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
386
396
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
387
397
|
|
|
388
398
|
# Gradient scaler and AMP related tasks
|
|
389
|
-
|
|
399
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
390
400
|
|
|
391
401
|
# Load states
|
|
392
402
|
if args.load_states is True:
|
|
@@ -546,6 +556,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
546
556
|
tic = time.time()
|
|
547
557
|
net.train()
|
|
548
558
|
|
|
559
|
+
# Clear metrics
|
|
560
|
+
running_loss.clear()
|
|
561
|
+
for tracker in loss_trackers.values():
|
|
562
|
+
tracker.clear()
|
|
563
|
+
|
|
549
564
|
validation_metrics.reset()
|
|
550
565
|
|
|
551
566
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
@@ -586,7 +601,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
586
601
|
|
|
587
602
|
# Forward, backward and optimize
|
|
588
603
|
with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
|
|
589
|
-
|
|
604
|
+
_detections, losses = net(inputs, targets, masks, image_sizes)
|
|
590
605
|
loss = sum(v for v in losses.values())
|
|
591
606
|
|
|
592
607
|
if scaler is not None:
|
|
@@ -708,7 +723,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
708
723
|
masks = masks.to(device, non_blocking=True)
|
|
709
724
|
|
|
710
725
|
with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
|
|
711
|
-
|
|
726
|
+
detections, losses = eval_model(inputs, masks=masks, image_sizes=image_sizes)
|
|
712
727
|
|
|
713
728
|
for target in targets:
|
|
714
729
|
# TorchMetrics can't handle "empty" images
|
birder/scripts/train_dino_v1.py
CHANGED
|
@@ -101,7 +101,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
101
101
|
#
|
|
102
102
|
# Initialize
|
|
103
103
|
#
|
|
104
|
-
|
|
104
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
105
105
|
|
|
106
106
|
if args.size is None:
|
|
107
107
|
args.size = registry.get_default_size(args.network)
|
|
@@ -129,11 +129,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
129
129
|
elif args.wds is True:
|
|
130
130
|
wds_path: str | list[str]
|
|
131
131
|
if args.wds_info is not None:
|
|
132
|
-
|
|
132
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
133
133
|
if args.wds_size is not None:
|
|
134
134
|
dataset_size = args.wds_size
|
|
135
135
|
else:
|
|
136
|
-
|
|
136
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
137
137
|
|
|
138
138
|
training_dataset = make_wds_dataset(
|
|
139
139
|
wds_path,
|
|
@@ -163,7 +163,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
163
163
|
|
|
164
164
|
# Data loaders and samplers
|
|
165
165
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
166
|
-
|
|
166
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
167
167
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
168
168
|
)
|
|
169
169
|
|
|
@@ -226,9 +226,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
226
226
|
|
|
227
227
|
network_name = get_mim_network_name("dino_v1", encoder=args.network, tag=args.tag)
|
|
228
228
|
|
|
229
|
-
student_backbone = registry.net_factory(args.network, sample_shape[1],
|
|
229
|
+
student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
230
230
|
if args.backbone_epoch is not None:
|
|
231
|
-
|
|
231
|
+
student_backbone, _ = fs_ops.load_simple_checkpoint(
|
|
232
232
|
device, student_backbone, backbone_name, epoch=args.backbone_epoch, strict=not args.non_strict_weights
|
|
233
233
|
)
|
|
234
234
|
|
|
@@ -239,7 +239,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
239
239
|
teacher_model_config = {"drop_path_rate": 0.0}
|
|
240
240
|
|
|
241
241
|
teacher_backbone = registry.net_factory(
|
|
242
|
-
args.network, sample_shape[1],
|
|
242
|
+
args.network, 0, sample_shape[1], config=teacher_model_config, size=args.size
|
|
243
243
|
)
|
|
244
244
|
if args.freeze_body is True:
|
|
245
245
|
student_backbone.freeze(freeze_classifier=False, unfreeze_features=True)
|
|
@@ -293,7 +293,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
293
293
|
|
|
294
294
|
if args.resume_epoch is not None:
|
|
295
295
|
begin_epoch = args.resume_epoch + 1
|
|
296
|
-
|
|
296
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
297
297
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
298
298
|
)
|
|
299
299
|
student = net["student"]
|
|
@@ -368,7 +368,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
368
368
|
wd_schedule = None
|
|
369
369
|
|
|
370
370
|
# Gradient scaler and AMP related tasks
|
|
371
|
-
|
|
371
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
372
372
|
|
|
373
373
|
# Load states
|
|
374
374
|
if args.load_states is True:
|
|
@@ -488,6 +488,10 @@ def train(args: argparse.Namespace) -> None:
|
|
|
488
488
|
tic = time.time()
|
|
489
489
|
net.train()
|
|
490
490
|
|
|
491
|
+
# Clear metrics
|
|
492
|
+
running_loss.clear()
|
|
493
|
+
train_proto_agreement.clear()
|
|
494
|
+
|
|
491
495
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
492
496
|
train_sampler.set_epoch(epoch)
|
|
493
497
|
|
birder/scripts/train_dino_v2.py
CHANGED
|
@@ -178,7 +178,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
178
178
|
#
|
|
179
179
|
# Initialize
|
|
180
180
|
#
|
|
181
|
-
|
|
181
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
182
182
|
|
|
183
183
|
if args.size is None:
|
|
184
184
|
args.size = registry.get_default_size(args.network)
|
|
@@ -207,7 +207,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
207
207
|
|
|
208
208
|
network_name = get_mim_network_name("dino_v2", encoder=args.network, tag=args.tag)
|
|
209
209
|
|
|
210
|
-
student_backbone = registry.net_factory(args.network, sample_shape[1],
|
|
210
|
+
student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
211
211
|
if args.model_config is not None:
|
|
212
212
|
teacher_model_config = args.model_config.copy()
|
|
213
213
|
teacher_model_config.update({"drop_path_rate": 0.0})
|
|
@@ -215,7 +215,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
215
215
|
teacher_model_config = {"drop_path_rate": 0.0}
|
|
216
216
|
|
|
217
217
|
teacher_backbone = registry.net_factory(
|
|
218
|
-
args.network, sample_shape[1],
|
|
218
|
+
args.network, 0, sample_shape[1], config=teacher_model_config, size=args.size
|
|
219
219
|
)
|
|
220
220
|
student_backbone.set_dynamic_size()
|
|
221
221
|
if args.ibot_separate_head is False:
|
|
@@ -267,7 +267,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
267
267
|
|
|
268
268
|
if args.resume_epoch is not None:
|
|
269
269
|
begin_epoch = args.resume_epoch + 1
|
|
270
|
-
|
|
270
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
271
271
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
272
272
|
)
|
|
273
273
|
student = net["student"]
|
|
@@ -336,11 +336,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
336
336
|
elif args.wds is True:
|
|
337
337
|
wds_path: str | list[str]
|
|
338
338
|
if args.wds_info is not None:
|
|
339
|
-
|
|
339
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
340
340
|
if args.wds_size is not None:
|
|
341
341
|
dataset_size = args.wds_size
|
|
342
342
|
else:
|
|
343
|
-
|
|
343
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
344
344
|
|
|
345
345
|
training_dataset = make_wds_dataset(
|
|
346
346
|
wds_path,
|
|
@@ -366,7 +366,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
366
366
|
|
|
367
367
|
# Data loaders and samplers
|
|
368
368
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
369
|
-
|
|
369
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
370
370
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
371
371
|
)
|
|
372
372
|
|
|
@@ -466,7 +466,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
466
466
|
wd_schedule = None
|
|
467
467
|
|
|
468
468
|
# Gradient scaler and AMP related tasks
|
|
469
|
-
|
|
469
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
470
470
|
|
|
471
471
|
# Load states
|
|
472
472
|
if args.load_states is True:
|
|
@@ -603,6 +603,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
603
603
|
tic = time.time()
|
|
604
604
|
net.train()
|
|
605
605
|
|
|
606
|
+
# Clear metrics
|
|
607
|
+
running_loss.clear()
|
|
608
|
+
running_loss_dino_local.clear()
|
|
609
|
+
running_loss_dino_global.clear()
|
|
610
|
+
running_loss_koleo.clear()
|
|
611
|
+
running_loss_ibot_patch.clear()
|
|
612
|
+
if track_extended_metrics is True:
|
|
613
|
+
train_proto_agreement.clear()
|
|
614
|
+
train_patch_agreement.clear()
|
|
615
|
+
running_target_entropy.clear()
|
|
616
|
+
running_dino_center_drift.clear()
|
|
617
|
+
running_ibot_center_drift.clear()
|
|
618
|
+
|
|
606
619
|
if args.sinkhorn_queue_size is not None:
|
|
607
620
|
queue_active = epoch > args.sinkhorn_queue_warmup_epochs
|
|
608
621
|
dino_loss.set_queue_active(queue_active)
|
|
@@ -661,7 +674,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
661
674
|
with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
|
|
662
675
|
with torch.no_grad():
|
|
663
676
|
# Teacher
|
|
664
|
-
|
|
677
|
+
teacher_embedding_after_head, teacher_masked_patch_tokens_after_head = teacher(
|
|
665
678
|
global_crops, n_global_crops, upper_bound, mask_indices_list
|
|
666
679
|
)
|
|
667
680
|
teacher_patch_tokens_raw = teacher_masked_patch_tokens_after_head
|
|
@@ -671,7 +684,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
671
684
|
prev_dino_center = dino_loss.center.clone()
|
|
672
685
|
prev_ibot_center = ibot_patch_loss.center.clone()
|
|
673
686
|
|
|
674
|
-
|
|
687
|
+
teacher_dino_softmax_centered = dino_loss.softmax_center_teacher(
|
|
675
688
|
teacher_embedding_after_head, teacher_temp=teacher_temp
|
|
676
689
|
).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
|
|
677
690
|
dino_loss.update_center(teacher_embedding_after_head)
|
|
@@ -684,7 +697,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
684
697
|
ibot_patch_loss.update_center(teacher_masked_patch_tokens_after_head[:, :n_masked_patches])
|
|
685
698
|
|
|
686
699
|
else: # sinkhorn_knopp
|
|
687
|
-
|
|
700
|
+
teacher_dino_softmax_centered = dino_loss.sinkhorn_knopp_teacher(
|
|
688
701
|
teacher_embedding_after_head, teacher_temp=teacher_temp
|
|
689
702
|
).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
|
|
690
703
|
|
|
@@ -705,7 +718,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
705
718
|
# Local DINO loss
|
|
706
719
|
loss_dino_local_crops = dino_loss(
|
|
707
720
|
student_local_embedding_after_head.chunk(n_local_crops),
|
|
708
|
-
|
|
721
|
+
teacher_dino_softmax_centered.unbind(0),
|
|
709
722
|
) / (n_global_crops_loss_terms + n_local_crops_loss_terms)
|
|
710
723
|
loss = args.dino_loss_weight * loss_dino_local_crops
|
|
711
724
|
|
|
@@ -715,7 +728,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
715
728
|
dino_loss(
|
|
716
729
|
[student_global_embedding_after_head],
|
|
717
730
|
[
|
|
718
|
-
|
|
731
|
+
teacher_dino_softmax_centered.flatten(0, 1)
|
|
719
732
|
], # These were chunked and stacked in reverse so A is matched to B
|
|
720
733
|
)
|
|
721
734
|
* loss_scales
|
|
@@ -809,7 +822,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
809
822
|
train_patch_agreement.update(training_utils.accuracy(pred_patch_teacher, pred_patch_student))
|
|
810
823
|
|
|
811
824
|
with torch.no_grad():
|
|
812
|
-
p =
|
|
825
|
+
p = teacher_dino_softmax_centered.detach()
|
|
813
826
|
p = p.reshape(-1, p.size(-1)) # (N, D)
|
|
814
827
|
|
|
815
828
|
# Mean distribution over prototypes (marginal)
|
|
@@ -179,7 +179,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
179
179
|
#
|
|
180
180
|
# Initialize
|
|
181
181
|
#
|
|
182
|
-
|
|
182
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
183
183
|
|
|
184
184
|
if args.size is None:
|
|
185
185
|
args.size = registry.get_default_size(args.network)
|
|
@@ -208,17 +208,17 @@ def train(args: argparse.Namespace) -> None:
|
|
|
208
208
|
|
|
209
209
|
network_name = get_mim_network_name("dino_v2_dist", encoder=args.network, tag=args.tag)
|
|
210
210
|
|
|
211
|
-
student_backbone = registry.net_factory(args.network, sample_shape[1],
|
|
211
|
+
student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
212
212
|
student_backbone_ema = registry.net_factory(
|
|
213
|
-
args.network, sample_shape[1],
|
|
213
|
+
args.network, 0, sample_shape[1], config=args.model_config, size=args.size
|
|
214
214
|
)
|
|
215
215
|
student_backbone_ema.load_state_dict(student_backbone.state_dict())
|
|
216
216
|
student_backbone_ema.requires_grad_(False)
|
|
217
217
|
|
|
218
218
|
teacher_backbone = registry.net_factory(
|
|
219
219
|
args.teacher,
|
|
220
|
-
sample_shape[1],
|
|
221
220
|
0,
|
|
221
|
+
sample_shape[1],
|
|
222
222
|
config=args.teacher_model_config,
|
|
223
223
|
size=args.size,
|
|
224
224
|
)
|
|
@@ -277,7 +277,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
277
277
|
|
|
278
278
|
if args.resume_epoch is not None:
|
|
279
279
|
begin_epoch = args.resume_epoch + 1
|
|
280
|
-
|
|
280
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
281
281
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
282
282
|
)
|
|
283
283
|
student = net["student"]
|
|
@@ -358,11 +358,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
358
358
|
elif args.wds is True:
|
|
359
359
|
wds_path: str | list[str]
|
|
360
360
|
if args.wds_info is not None:
|
|
361
|
-
|
|
361
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
362
362
|
if args.wds_size is not None:
|
|
363
363
|
dataset_size = args.wds_size
|
|
364
364
|
else:
|
|
365
|
-
|
|
365
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
366
366
|
|
|
367
367
|
training_dataset = make_wds_dataset(
|
|
368
368
|
wds_path,
|
|
@@ -388,7 +388,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
388
388
|
|
|
389
389
|
# Data loaders and samplers
|
|
390
390
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
391
|
-
|
|
391
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
392
392
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
393
393
|
)
|
|
394
394
|
|
|
@@ -487,7 +487,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
487
487
|
wd_schedule = None
|
|
488
488
|
|
|
489
489
|
# Gradient scaler and AMP related tasks
|
|
490
|
-
|
|
490
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
491
491
|
|
|
492
492
|
# Load states
|
|
493
493
|
if args.load_states is True:
|
|
@@ -625,6 +625,19 @@ def train(args: argparse.Namespace) -> None:
|
|
|
625
625
|
net.train()
|
|
626
626
|
teacher.eval()
|
|
627
627
|
|
|
628
|
+
# Clear metrics
|
|
629
|
+
running_loss.clear()
|
|
630
|
+
running_loss_dino_local.clear()
|
|
631
|
+
running_loss_dino_global.clear()
|
|
632
|
+
running_loss_koleo.clear()
|
|
633
|
+
running_loss_ibot_patch.clear()
|
|
634
|
+
if track_extended_metrics is True:
|
|
635
|
+
train_proto_agreement.clear()
|
|
636
|
+
train_patch_agreement.clear()
|
|
637
|
+
running_target_entropy.clear()
|
|
638
|
+
running_dino_center_drift.clear()
|
|
639
|
+
running_ibot_center_drift.clear()
|
|
640
|
+
|
|
628
641
|
if args.sinkhorn_queue_size is not None:
|
|
629
642
|
queue_active = epoch > args.sinkhorn_queue_warmup_epochs
|
|
630
643
|
dino_loss.set_queue_active(queue_active)
|
|
@@ -682,7 +695,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
682
695
|
with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
|
|
683
696
|
with torch.no_grad():
|
|
684
697
|
# Teacher
|
|
685
|
-
|
|
698
|
+
teacher_embedding_after_head, teacher_masked_patch_tokens_after_head = teacher(
|
|
686
699
|
global_crops, n_global_crops, upper_bound, mask_indices_list
|
|
687
700
|
)
|
|
688
701
|
teacher_patch_tokens_raw = teacher_masked_patch_tokens_after_head
|
|
@@ -692,7 +705,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
692
705
|
prev_dino_center = dino_loss.center.clone()
|
|
693
706
|
prev_ibot_center = ibot_patch_loss.center.clone()
|
|
694
707
|
|
|
695
|
-
|
|
708
|
+
teacher_dino_softmax_centered = dino_loss.softmax_center_teacher(
|
|
696
709
|
teacher_embedding_after_head, teacher_temp=teacher_temp
|
|
697
710
|
).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
|
|
698
711
|
dino_loss.update_center(teacher_embedding_after_head)
|
|
@@ -705,7 +718,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
705
718
|
ibot_patch_loss.update_center(teacher_masked_patch_tokens_after_head[:, :n_masked_patches])
|
|
706
719
|
|
|
707
720
|
else: # sinkhorn_knopp
|
|
708
|
-
|
|
721
|
+
teacher_dino_softmax_centered = dino_loss.sinkhorn_knopp_teacher(
|
|
709
722
|
teacher_embedding_after_head, teacher_temp=teacher_temp
|
|
710
723
|
).view(n_global_crops, -1, *teacher_embedding_after_head.shape[1:])
|
|
711
724
|
|
|
@@ -726,7 +739,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
726
739
|
# Local DINO loss
|
|
727
740
|
loss_dino_local_crops = dino_loss(
|
|
728
741
|
student_local_embedding_after_head.chunk(n_local_crops),
|
|
729
|
-
|
|
742
|
+
teacher_dino_softmax_centered.unbind(0),
|
|
730
743
|
) / (n_global_crops_loss_terms + n_local_crops_loss_terms)
|
|
731
744
|
loss = args.dino_loss_weight * loss_dino_local_crops
|
|
732
745
|
|
|
@@ -736,7 +749,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
736
749
|
dino_loss(
|
|
737
750
|
[student_global_embedding_after_head],
|
|
738
751
|
[
|
|
739
|
-
|
|
752
|
+
teacher_dino_softmax_centered.flatten(0, 1)
|
|
740
753
|
], # These were chunked and stacked in reverse so A is matched to B
|
|
741
754
|
)
|
|
742
755
|
* loss_scales
|
|
@@ -830,7 +843,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
830
843
|
train_patch_agreement.update(training_utils.accuracy(pred_patch_teacher, pred_patch_student))
|
|
831
844
|
|
|
832
845
|
with torch.no_grad():
|
|
833
|
-
p =
|
|
846
|
+
p = teacher_dino_softmax_centered.detach()
|
|
834
847
|
p = p.reshape(-1, p.size(-1)) # (N, D)
|
|
835
848
|
|
|
836
849
|
# Mean distribution over prototypes (marginal)
|
birder/scripts/train_franca.py
CHANGED
|
@@ -205,7 +205,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
205
205
|
#
|
|
206
206
|
# Initialize
|
|
207
207
|
#
|
|
208
|
-
|
|
208
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
209
209
|
|
|
210
210
|
if args.size is None:
|
|
211
211
|
args.size = registry.get_default_size(args.network)
|
|
@@ -234,7 +234,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
234
234
|
|
|
235
235
|
network_name = get_mim_network_name("franca", encoder=args.network, tag=args.tag)
|
|
236
236
|
|
|
237
|
-
student_backbone = registry.net_factory(args.network, sample_shape[1],
|
|
237
|
+
student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
238
238
|
if args.model_config is not None:
|
|
239
239
|
teacher_model_config = args.model_config.copy()
|
|
240
240
|
teacher_model_config.update({"drop_path_rate": 0.0})
|
|
@@ -242,7 +242,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
242
242
|
teacher_model_config = {"drop_path_rate": 0.0}
|
|
243
243
|
|
|
244
244
|
teacher_backbone = registry.net_factory(
|
|
245
|
-
args.network, sample_shape[1],
|
|
245
|
+
args.network, 0, sample_shape[1], config=teacher_model_config, size=args.size
|
|
246
246
|
)
|
|
247
247
|
student_backbone.set_dynamic_size()
|
|
248
248
|
if args.ibot_separate_head is False:
|
|
@@ -296,7 +296,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
296
296
|
|
|
297
297
|
if args.resume_epoch is not None:
|
|
298
298
|
begin_epoch = args.resume_epoch + 1
|
|
299
|
-
|
|
299
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
300
300
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
301
301
|
)
|
|
302
302
|
student = net["student"]
|
|
@@ -363,11 +363,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
363
363
|
elif args.wds is True:
|
|
364
364
|
wds_path: str | list[str]
|
|
365
365
|
if args.wds_info is not None:
|
|
366
|
-
|
|
366
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
367
367
|
if args.wds_size is not None:
|
|
368
368
|
dataset_size = args.wds_size
|
|
369
369
|
else:
|
|
370
|
-
|
|
370
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
371
371
|
|
|
372
372
|
training_dataset = make_wds_dataset(
|
|
373
373
|
wds_path,
|
|
@@ -393,7 +393,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
393
393
|
|
|
394
394
|
# Data loaders and samplers
|
|
395
395
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
396
|
-
|
|
396
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
397
397
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
398
398
|
)
|
|
399
399
|
|
|
@@ -493,7 +493,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
493
493
|
wd_schedule = None
|
|
494
494
|
|
|
495
495
|
# Gradient scaler and AMP related tasks
|
|
496
|
-
|
|
496
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
497
497
|
|
|
498
498
|
# Load states
|
|
499
499
|
if args.load_states is True:
|
|
@@ -623,6 +623,13 @@ def train(args: argparse.Namespace) -> None:
|
|
|
623
623
|
tic = time.time()
|
|
624
624
|
net.train()
|
|
625
625
|
|
|
626
|
+
# Clear metrics
|
|
627
|
+
running_loss.clear()
|
|
628
|
+
running_loss_dino_local.clear()
|
|
629
|
+
running_loss_dino_global.clear()
|
|
630
|
+
running_loss_koleo.clear()
|
|
631
|
+
running_loss_ibot_patch.clear()
|
|
632
|
+
|
|
626
633
|
if args.sinkhorn_queue_size is not None:
|
|
627
634
|
queue_active = epoch > args.sinkhorn_queue_warmup_epochs
|
|
628
635
|
dino_loss.set_queue_active(queue_active)
|
|
@@ -681,7 +688,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
681
688
|
with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
|
|
682
689
|
with torch.no_grad():
|
|
683
690
|
# Teacher
|
|
684
|
-
|
|
691
|
+
teacher_embedding_after_head, teacher_masked_patch_tokens_after_head = teacher(
|
|
685
692
|
global_crops, n_global_crops, upper_bound, mask_indices_list
|
|
686
693
|
)
|
|
687
694
|
|