birder 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +18 -0
- birder/common/training_utils.py +123 -10
- birder/data/collators/detection.py +10 -3
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/net/__init__.py +8 -0
- birder/net/detection/efficientdet.py +65 -86
- birder/net/detection/rt_detr_v1.py +1 -0
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +39 -40
- birder/net/detection/yolo_v4.py +28 -26
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/fasternet.py +1 -1
- birder/net/gc_vit.py +671 -0
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +342 -0
- birder/net/lit_v2.py +436 -0
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/resnet_v1.py +1 -1
- birder/net/resnext.py +67 -25
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +2 -2
- birder/net/vit.py +0 -15
- birder/net/vovnet_v2.py +31 -1
- birder/scripts/benchmark.py +90 -21
- birder/scripts/predict.py +1 -0
- birder/scripts/predict_detection.py +18 -11
- birder/scripts/train.py +10 -34
- birder/scripts/train_barlow_twins.py +10 -34
- birder/scripts/train_byol.py +10 -34
- birder/scripts/train_capi.py +10 -35
- birder/scripts/train_data2vec.py +9 -34
- birder/scripts/train_data2vec2.py +9 -34
- birder/scripts/train_detection.py +48 -40
- birder/scripts/train_dino_v1.py +10 -34
- birder/scripts/train_dino_v2.py +9 -34
- birder/scripts/train_dino_v2_dist.py +9 -34
- birder/scripts/train_franca.py +9 -34
- birder/scripts/train_i_jepa.py +9 -34
- birder/scripts/train_ibot.py +9 -34
- birder/scripts/train_kd.py +156 -64
- birder/scripts/train_mim.py +10 -34
- birder/scripts/train_mmcr.py +10 -34
- birder/scripts/train_rotnet.py +10 -34
- birder/scripts/train_simclr.py +10 -34
- birder/scripts/train_vicreg.py +10 -34
- birder/tools/auto_anchors.py +20 -1
- birder/tools/pack.py +172 -103
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/scripts/train_kd.py
CHANGED
@@ -4,6 +4,7 @@ Supports:
 * Logits matching (Soft distillation), https://arxiv.org/abs/1503.02531
 * Hard-label distillation, https://arxiv.org/pdf/2012.12877
 * Distillation token, https://arxiv.org/pdf/2012.12877
+* Embedding matching (L2-normalized MSE)
 """

 import argparse
@@ -16,6 +17,7 @@ import typing
 from pathlib import Path
 from typing import Any
 from typing import Literal
+from typing import Optional

 import matplotlib.pyplot as plt
 import numpy as np
@@ -39,7 +41,6 @@ from birder.common import training_cli
 from birder.common import training_utils
 from birder.common.lib import format_duration
 from birder.common.lib import get_network_name
-from birder.common.lib import set_random_seeds
 from birder.conf import settings
 from birder.data.dataloader.webdataset import make_wds_loader
 from birder.data.datasets.directory import HierarchicalImageFolder
@@ -55,7 +56,18 @@ from birder.net.base import get_signature

 logger = logging.getLogger(__name__)

-DistType = Literal["soft", "hard", "deit"]
+DistType = Literal["soft", "hard", "deit", "embedding"]
+
+
+class EmbeddingDistillWrapper(torch.nn.Module):
+    def __init__(self, model: torch.nn.Module) -> None:
+        super().__init__()
+        self.model = model
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        embedding = self.model.embedding(x)
+        outputs = self.model.classify(embedding)
+        return (outputs, embedding)


 # pylint: disable=too-many-locals,too-many-branches,too-many-statements
@@ -63,41 +75,11 @@ def train(args: argparse.Namespace) -> None:
     #
     # Initialize
     #
-    training_utils.
-    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-    training_utils.log_git_info()
+    (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

     if args.type != "soft":
         args.temperature = 1.0

-    logger.info(f"Using size={args.size}")
-
-    if args.cpu is True:
-        device = torch.device("cpu")
-        device_id = 0
-    else:
-        device = torch.device("cuda")
-        device_id = torch.cuda.current_device()
-
-    if args.use_deterministic_algorithms is True:
-        torch.backends.cudnn.benchmark = False
-        torch.use_deterministic_algorithms(True)
-    else:
-        torch.backends.cudnn.benchmark = True
-
-    if args.seed is not None:
-        set_random_seeds(args.seed)
-
-    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-        disable_tqdm = True
-    elif sys.stderr.isatty() is False:
-        disable_tqdm = True
-    else:
-        disable_tqdm = False
-
-    # Enable or disable the autograd anomaly detection
-    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
     # Using the teacher rgb values for the student
     (teacher, (class_to_idx, signature, rgb_stats, *_)) = fs_ops.load_model(
         device,
@@ -112,7 +94,8 @@ def train(args: argparse.Namespace) -> None:
     )
     if args.size is None:
         args.size = lib.get_size_from_signature(signature)
-
+
+    logger.info(f"Using size={args.size}")

     #
     # Data
@@ -188,7 +171,7 @@ def train(args: argparse.Namespace) -> None:
     batch_size: int = args.batch_size
     grad_accum_steps: int = args.grad_accum_steps
     model_ema_steps: int = args.model_ema_steps
-    logger.debug(f"Effective batch size = {
+    logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

     # Set data iterators
     if args.mixup_alpha is not None or args.cutmix is True:
@@ -258,6 +241,8 @@ def train(args: argparse.Namespace) -> None:
     else:
         args.stop_epoch += 1

+    logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
     #
     # Initialize networks
     #
@@ -298,33 +283,61 @@ def train(args: argparse.Namespace) -> None:
     if args.fast_matmul is True or args.amp is True:
         torch.set_float32_matmul_precision("high")

-
-
-
-    student
-
-
+    distillation_type: DistType = args.type
+    embedding_projection: Optional[torch.nn.Module] = None
+    if distillation_type == "embedding":
+        if student.embedding_size == teacher.embedding_size:
+            embedding_projection = torch.nn.Identity()
+        else:
+            logger.info(
+                f"Creating embedding projection layer from {student.embedding_size} to {teacher.embedding_size}"
+            )
+            embedding_projection = torch.nn.Linear(student.embedding_size, teacher.embedding_size)
+
+        embedding_projection.to(device, dtype=model_dtype)
+        if training_states.extra_states is not None:
+            projection_state = training_states.extra_states.get("embedding_projection")
+            if projection_state is not None:
+                embedding_projection.load_state_dict(projection_state)

     #
     # Loss criteria, optimizer, learning rate scheduler and training parameter groups
     #

+    # Learning rate scaling
+    lr = training_utils.scale_lr(args)
+
     # Training parameter groups and loss criteria
     custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
     parameters = training_utils.optimizer_parameter_groups(
         student,
         args.wd,
+        base_lr=lr,
         norm_weight_decay=args.norm_wd,
         custom_keys_weight_decay=custom_keys_weight_decay,
+        custom_layer_weight_decay=args.custom_layer_wd,
         layer_decay=args.layer_decay,
         layer_decay_min_scale=args.layer_decay_min_scale,
         layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
         bias_lr=args.bias_lr,
+        custom_layer_lr_scale=args.custom_layer_lr_scale,
     )
+    if embedding_projection is not None:
+        projection_parameters = training_utils.optimizer_parameter_groups(
+            embedding_projection,
+            args.wd,
+            base_lr=lr,
+            norm_weight_decay=args.norm_wd,
+            custom_keys_weight_decay=custom_keys_weight_decay,
+            custom_layer_weight_decay=args.custom_layer_wd,
+            bias_lr=args.bias_lr,
+            custom_layer_lr_scale=args.custom_layer_lr_scale,
+        )
+        parameters.extend(projection_parameters)
+
     criterion = torch.nn.CrossEntropyLoss(label_smoothing=args.smoothing_alpha)

     # Distillation
-    distillation_type: DistType = args.type
     if distillation_type == "soft":
         distillation_criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)
     elif distillation_type == "hard":
@@ -332,11 +345,11 @@ def train(args: argparse.Namespace) -> None:
     elif distillation_type == "deit":
         distillation_criterion = torch.nn.CrossEntropyLoss()
         student.set_distillation_output()
+    elif distillation_type == "embedding":
+        distillation_criterion = torch.nn.MSELoss()
     else:
         raise ValueError(f"Unknown KD type: {args.type}")

-    # Learning rate scaling
-    lr = training_utils.scale_lr(args)
     if args.lr_scheduler_update == "epoch":
         step_update = False
         scheduler_steps_per_epoch = 1
@@ -398,12 +411,50 @@ def train(args: argparse.Namespace) -> None:
         ema_warmup_steps = 0

     logger.debug(f"EMA warmup steps = {ema_warmup_steps}")
+    train_student = student
+    if distillation_type == "embedding":
+        train_student = EmbeddingDistillWrapper(student)
+
+    # Compile networks
+    if args.compile is True:
+        train_student = torch.compile(train_student)
+        if distillation_type == "embedding":
+            teacher.embedding = torch.compile(teacher.embedding)
+            embedding_projection = torch.compile(embedding_projection)
+            student = torch.compile(student)  # For validation
+        else:
+            teacher = torch.compile(teacher)
+            student = train_student
+
+    elif args.compile_teacher is True:
+        if distillation_type == "embedding":
+            teacher.embedding = torch.compile(teacher.embedding)
+        else:
+            teacher = torch.compile(teacher)
+
     net_without_ddp = student
     if args.distributed is True:
-
-
+        train_student = torch.nn.parallel.DistributedDataParallel(
+            train_student, device_ids=[args.local_rank], find_unused_parameters=args.find_unused_parameters
         )
-
+        if distillation_type != "embedding":
+            net_without_ddp = train_student.module
+
+    embedding_projection_to_save = None
+    if embedding_projection is not None:
+        if args.distributed is True and any(p.requires_grad for p in embedding_projection.parameters()):
+            embedding_projection = torch.nn.parallel.DistributedDataParallel(
+                embedding_projection,
+                device_ids=[args.local_rank],
+                find_unused_parameters=args.find_unused_parameters,
+            )
+            embedding_projection_to_save = embedding_projection.module
+        else:
+            embedding_projection_to_save = embedding_projection
+
+        # Unwrap compiled module for saving
+        if hasattr(embedding_projection_to_save, "_orig_mod"):
+            embedding_projection_to_save = embedding_projection_to_save._orig_mod  # pylint: disable=protected-access

     if args.model_ema is True:
         model_base = net_without_ddp  # Original model without DDP wrapper, will be saved as training state
@@ -499,7 +550,10 @@ def train(args: argparse.Namespace) -> None:
     logger.info(f"Starting training with learning rate of {last_lr}")
     for epoch in range(begin_epoch, args.stop_epoch):
         tic = time.time()
-
+        train_student.train()
+        if embedding_projection is not None:
+            embedding_projection.train()
+
         running_loss = training_utils.SmoothedValue(window_size=64)
         running_val_loss = training_utils.SmoothedValue()
         train_accuracy = training_utils.SmoothedValue(window_size=64)
@@ -531,22 +585,37 @@ def train(args: argparse.Namespace) -> None:

             # Forward, backward and optimize
             with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
-
-
-
-
-
-                outputs =
-
-
-
-
-                elif distillation_type == "deit":
-                    (outputs, dist_output) = torch.unbind(student(inputs), dim=1)
+                if distillation_type == "embedding":
+                    with torch.no_grad():
+                        teacher_embedding = teacher.embedding(inputs)
+                        teacher_embedding = F.normalize(teacher_embedding, dim=-1)
+
+                    (outputs, student_embedding) = train_student(inputs)
+                    student_embedding = embedding_projection(student_embedding)  # type: ignore[misc]
+                    student_embedding = F.normalize(student_embedding, dim=-1)
+                    dist_loss = distillation_criterion(student_embedding, teacher_embedding)
+
                 else:
-
+                    with torch.no_grad():
+                        teacher_outputs = teacher(inputs)
+                        if distillation_type == "soft":
+                            teacher_targets = F.softmax(teacher_outputs / args.temperature, dim=-1)
+                        else:
+                            teacher_targets = teacher_outputs.argmax(dim=-1)
+
+                    if distillation_type == "soft":
+                        outputs = train_student(inputs)
+                        dist_output = F.log_softmax(outputs / args.temperature, dim=-1)
+                        dist_loss = distillation_criterion(dist_output, teacher_targets) * (args.temperature**2)
+                    elif distillation_type == "hard":
+                        outputs = train_student(inputs)
+                        dist_loss = distillation_criterion(outputs, teacher_targets)
+                    elif distillation_type == "deit":
+                        (outputs, dist_output) = torch.unbind(train_student(inputs), dim=1)
+                        dist_loss = distillation_criterion(dist_output, teacher_targets)
+                    else:
+                        raise RuntimeError

-                dist_loss = distillation_criterion(dist_output, softmax_teacher) * (args.temperature**2)
                 target_loss = criterion(outputs, targets)
                 loss = (1 - args.lambda_param) * target_loss + (args.lambda_param * dist_loss)

@@ -555,7 +624,11 @@ def train(args: argparse.Namespace) -> None:
             if optimizer_update is True:
                 if args.clip_grad_norm is not None:
                     scaler.unscale_(optimizer)
-
+                    params = list(train_student.parameters())
+                    if embedding_projection is not None:
+                        params += list(embedding_projection.parameters())
+
+                    torch.nn.utils.clip_grad_norm_(params, args.clip_grad_norm)

                 scaler.step(optimizer)
                 scaler.update()
@@ -567,7 +640,11 @@ def train(args: argparse.Namespace) -> None:
                 loss.backward()
                 if optimizer_update is True:
                     if args.clip_grad_norm is not None:
-
+                        params = list(train_student.parameters())
+                        if embedding_projection is not None:
+                            params += list(embedding_projection.parameters())
+
+                        torch.nn.utils.clip_grad_norm_(params, args.clip_grad_norm)

                     optimizer.step()
                     optimizer.zero_grad()
@@ -710,6 +787,10 @@ def train(args: argparse.Namespace) -> None:
         if training_utils.is_local_primary(args) is True:
             # Checkpoint model
             if epoch % args.save_frequency == 0:
+                extra_states = {}
+                if embedding_projection_to_save is not None:
+                    extra_states["embedding_projection"] = embedding_projection_to_save.state_dict()
+
                 fs_ops.checkpoint_model(
                     student_name,
                     epoch,
@@ -721,6 +802,7 @@ def train(args: argparse.Namespace) -> None:
                     scheduler,
                     scaler,
                     model_base,
+                    **extra_states,
                 )
                 if args.keep_last is not None:
                     fs_ops.clean_checkpoints(student_name, args.keep_last)
@@ -766,6 +848,10 @@ def train(args: argparse.Namespace) -> None:

     # Checkpoint model
     if training_utils.is_local_primary(args) is True:
+        extra_states = {}
+        if embedding_projection_to_save is not None:
+            extra_states["embedding_projection"] = embedding_projection_to_save.state_dict()
+
         fs_ops.checkpoint_model(
             student_name,
             epoch,
@@ -777,6 +863,7 @@ def train(args: argparse.Namespace) -> None:
             scheduler,
             scaler,
             model_base,
+            **extra_states,
         )

     training_utils.shutdown_distributed_mode(args)
@@ -896,6 +983,8 @@ def validate_args(args: argparse.Namespace) -> None:
     training_cli.common_args_validation(args)

     # Script specific checks
+    if args.type is None:
+        raise cli.ValidationError("--type is required")
     if args.teacher is None:
         raise cli.ValidationError("--teacher is required")
     if args.student is None:
@@ -905,6 +994,9 @@ def validate_args(args: argparse.Namespace) -> None:
     if registry.exists(args.student, task=Task.IMAGE_CLASSIFICATION) is False:
         raise cli.ValidationError(f"--student {args.student} not supported, see list-models tool for available options")

+    if args.type == "embedding" and (args.pts is True or args.pt2 is True):
+        raise cli.ValidationError("--type embedding does not support --pts or --pt2 teachers")
+
     if args.smoothing_alpha < 0 or args.smoothing_alpha >= 0.5:
         raise cli.ValidationError(f"--smoothing-alpha must be in range of [0, 0.5), got {args.smoothing_alpha}")
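The new "embedding" mode above distills by matching L2-normalized student embeddings against teacher embeddings computed under torch.no_grad(), with an MSE criterion and a linear projection only when the two embedding sizes differ. A minimal self-contained sketch of that loss term, assuming generic tensors and a projection module (the function name and shapes below are illustrative, not part of the birder API):

import torch
import torch.nn.functional as F


def embedding_distillation_loss(
    student_embedding: torch.Tensor,  # (N, D_student)
    teacher_embedding: torch.Tensor,  # (N, D_teacher), computed under torch.no_grad()
    projection: torch.nn.Module,  # nn.Identity() when sizes match, else nn.Linear(D_student, D_teacher)
) -> torch.Tensor:
    # Project the student embedding into the teacher's embedding space when the sizes differ
    student_embedding = projection(student_embedding)

    # L2-normalize both sides so the MSE compares directions rather than magnitudes
    student_embedding = F.normalize(student_embedding, dim=-1)
    teacher_embedding = F.normalize(teacher_embedding, dim=-1)

    return F.mse_loss(student_embedding, teacher_embedding)

As in the other distillation modes, the script then blends this with the label loss as loss = (1 - lambda_param) * target_loss + lambda_param * dist_loss.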
birder/scripts/train_mim.py
CHANGED
@@ -25,7 +25,6 @@ from birder.common import training_utils
 from birder.common.lib import format_duration
 from birder.common.lib import get_mim_network_name
 from birder.common.lib import get_network_name
-from birder.common.lib import set_random_seeds
 from birder.conf import settings
 from birder.data.dataloader.webdataset import make_wds_loader
 from birder.data.datasets.directory import make_image_dataset
@@ -49,9 +48,7 @@ def train(args: argparse.Namespace) -> None:
     #
     # Initialize
     #
-    training_utils.
-    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-    training_utils.log_git_info()
+    (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

     if args.size is None:
         # Prefer mim size over encoder default size
@@ -59,32 +56,6 @@ def train(args: argparse.Namespace) -> None:

     logger.info(f"Using size={args.size}")

-    if args.cpu is True:
-        device = torch.device("cpu")
-        device_id = 0
-    else:
-        device = torch.device("cuda")
-        device_id = torch.cuda.current_device()
-
-    if args.use_deterministic_algorithms is True:
-        torch.backends.cudnn.benchmark = False
-        torch.use_deterministic_algorithms(True)
-    else:
-        torch.backends.cudnn.benchmark = True
-
-    if args.seed is not None:
-        set_random_seeds(args.seed)
-
-    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-        disable_tqdm = True
-    elif sys.stderr.isatty() is False:
-        disable_tqdm = True
-    else:
-        disable_tqdm = False
-
-    # Enable or disable the autograd anomaly detection
-    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
     #
     # Data
     #
@@ -131,7 +102,7 @@ def train(args: argparse.Namespace) -> None:

     batch_size: int = args.batch_size
     grad_accum_steps: int = args.grad_accum_steps
-    logger.debug(f"Effective batch size = {
+    logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -172,6 +143,8 @@ def train(args: argparse.Namespace) -> None:
     else:
         args.stop_epoch += 1

+    logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
     #
     # Initialize network
     #
@@ -241,22 +214,25 @@ def train(args: argparse.Namespace) -> None:
     # Loss criteria, optimizer, learning rate scheduler and training parameter groups
     #

+    # Learning rate scaling
+    lr = training_utils.scale_lr(args)
+
     # Training parameter groups
     custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
     parameters = training_utils.optimizer_parameter_groups(
         net,
         args.wd,
+        base_lr=lr,
         norm_weight_decay=args.norm_wd,
         custom_keys_weight_decay=custom_keys_weight_decay,
+        custom_layer_weight_decay=args.custom_layer_wd,
         layer_decay=args.layer_decay,
         layer_decay_min_scale=args.layer_decay_min_scale,
         layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
         bias_lr=args.bias_lr,
+        custom_layer_lr_scale=args.custom_layer_lr_scale,
     )

-    # Learning rate scaling
-    lr = training_utils.scale_lr(args)
-
     if args.lr_scheduler_update == "epoch":
         step_update = False
         scheduler_steps_per_epoch = 1
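The rewritten debug line in this and the other training scripts spells out how the effective batch size is computed: per-device batch size times gradient-accumulation steps times the number of distributed processes. A small worked example with assumed, illustrative values:

# Illustrative values only
batch_size = 32  # per-device batch size
grad_accum_steps = 4  # the optimizer steps once every 4 batches
world_size = 8  # number of distributed processes

effective_batch_size = batch_size * grad_accum_steps * world_size
print(effective_batch_size)  # 1024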
birder/scripts/train_mmcr.py
CHANGED
@@ -36,7 +36,6 @@ from birder.common import training_utils
 from birder.common.lib import format_duration
 from birder.common.lib import get_mim_network_name
 from birder.common.lib import get_network_name
-from birder.common.lib import set_random_seeds
 from birder.conf import settings
 from birder.data.dataloader.webdataset import make_wds_loader
 from birder.data.datasets.directory import make_image_dataset
@@ -74,41 +73,13 @@ def train(args: argparse.Namespace) -> None:
     #
     # Initialize
     #
-    training_utils.
-    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-    training_utils.log_git_info()
+    (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

     if args.size is None:
         args.size = registry.get_default_size(args.network)

     logger.info(f"Using size={args.size}")

-    if args.cpu is True:
-        device = torch.device("cpu")
-        device_id = 0
-    else:
-        device = torch.device("cuda")
-        device_id = torch.cuda.current_device()
-
-    if args.use_deterministic_algorithms is True:
-        torch.backends.cudnn.benchmark = False
-        torch.use_deterministic_algorithms(True)
-    else:
-        torch.backends.cudnn.benchmark = True
-
-    if args.seed is not None:
-        set_random_seeds(args.seed)
-
-    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-        disable_tqdm = True
-    elif sys.stderr.isatty() is False:
-        disable_tqdm = True
-    else:
-        disable_tqdm = False
-
-    # Enable or disable the autograd anomaly detection
-    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
     #
     # Data
     #
@@ -155,7 +126,7 @@ def train(args: argparse.Namespace) -> None:

     batch_size: int = args.batch_size
     grad_accum_steps: int = args.grad_accum_steps
-    logger.debug(f"Effective batch size = {
+    logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -196,6 +167,8 @@ def train(args: argparse.Namespace) -> None:
     else:
         args.stop_epoch += 1

+    logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
     #
     # Initialize network
     #
@@ -243,22 +216,25 @@ def train(args: argparse.Namespace) -> None:
     # Loss
     mmcr_loss = MMCRMomentumLoss(args.lambda_coeff, n_aug=args.n_aug)

+    # Learning rate scaling
+    lr = training_utils.scale_lr(args)
+
     # Training parameter groups
     custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
     parameters = training_utils.optimizer_parameter_groups(
         net,
         args.wd,
+        base_lr=lr,
         norm_weight_decay=args.norm_wd,
         custom_keys_weight_decay=custom_keys_weight_decay,
+        custom_layer_weight_decay=args.custom_layer_wd,
         layer_decay=args.layer_decay,
         layer_decay_min_scale=args.layer_decay_min_scale,
         layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
         bias_lr=args.bias_lr,
+        custom_layer_lr_scale=args.custom_layer_lr_scale,
     )

-    # Learning rate scaling
-    lr = training_utils.scale_lr(args)
-
     if args.lr_scheduler_update == "epoch":
         step_update = False
         scheduler_steps_per_epoch = 1
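Across these scripts the learning-rate scaling now runs before the parameter groups are built so the scaled value can be passed as base_lr, alongside the new custom_layer_weight_decay and custom_layer_lr_scale knobs. A rough sketch of what per-group base learning rates with a per-layer scale look like in general; this is a generic illustration, not the birder training_utils implementation:

import torch


def parameter_groups_with_base_lr(
    model: torch.nn.Module,
    weight_decay: float,
    base_lr: float,
    custom_layer_lr_scale: dict[str, float] | None = None,
) -> list[dict]:
    # Attach an explicit lr to every group so per-layer scales are carried into the optimizer
    custom_layer_lr_scale = custom_layer_lr_scale or {}
    groups = []
    for name, param in model.named_parameters():
        if param.requires_grad is False:
            continue

        scale = 1.0
        for prefix, layer_scale in custom_layer_lr_scale.items():
            if name.startswith(prefix):
                scale = layer_scale
                break

        groups.append({"params": [param], "weight_decay": weight_decay, "lr": base_lr * scale})

    return groups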
birder/scripts/train_rotnet.py
CHANGED
@@ -31,7 +31,6 @@ from birder.common import training_cli
 from birder.common import training_utils
 from birder.common.lib import format_duration
 from birder.common.lib import get_network_name
-from birder.common.lib import set_random_seeds
 from birder.conf import settings
 from birder.data.dataloader.webdataset import make_wds_loader
 from birder.data.datasets.directory import make_image_dataset
@@ -83,41 +82,13 @@ def train(args: argparse.Namespace) -> None:
     #
     # Initialize
     #
-    training_utils.
-    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-    training_utils.log_git_info()
+    (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

     if args.size is None:
         args.size = registry.get_default_size(args.network)

     logger.info(f"Using size={args.size}")

-    if args.cpu is True:
-        device = torch.device("cpu")
-        device_id = 0
-    else:
-        device = torch.device("cuda")
-        device_id = torch.cuda.current_device()
-
-    if args.use_deterministic_algorithms is True:
-        torch.backends.cudnn.benchmark = False
-        torch.use_deterministic_algorithms(True)
-    else:
-        torch.backends.cudnn.benchmark = True
-
-    if args.seed is not None:
-        set_random_seeds(args.seed)
-
-    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-        disable_tqdm = True
-    elif sys.stderr.isatty() is False:
-        disable_tqdm = True
-    else:
-        disable_tqdm = False
-
-    # Enable or disable the autograd anomaly detection
-    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
     #
     # Data
     #
@@ -169,7 +140,7 @@ def train(args: argparse.Namespace) -> None:

     batch_size: int = args.batch_size
     grad_accum_steps: int = args.grad_accum_steps
-    logger.debug(f"Effective batch size = {
+    logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -210,6 +181,8 @@ def train(args: argparse.Namespace) -> None:
     else:
         args.stop_epoch += 1

+    logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
     #
     # Initialize network
     #
@@ -252,25 +225,28 @@ def train(args: argparse.Namespace) -> None:
     # Loss criteria, optimizer, learning rate scheduler and training parameter groups
     #

+    # Learning rate scaling
+    lr = training_utils.scale_lr(args)
+
     # Training parameter groups
     custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
     parameters = training_utils.optimizer_parameter_groups(
         net,
         args.wd,
+        base_lr=lr,
         norm_weight_decay=args.norm_wd,
         custom_keys_weight_decay=custom_keys_weight_decay,
+        custom_layer_weight_decay=args.custom_layer_wd,
         layer_decay=args.layer_decay,
         layer_decay_min_scale=args.layer_decay_min_scale,
         layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
         bias_lr=args.bias_lr,
+        custom_layer_lr_scale=args.custom_layer_lr_scale,
     )

     # Loss criteria
     criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

-    # Learning rate scaling
-    lr = training_utils.scale_lr(args)
-
     if args.lr_scheduler_update == "epoch":
         step_update = False
         scheduler_steps_per_epoch = 1
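All four training scripts shown here replace the same block of inline setup (device selection, deterministic algorithms, seeding, tqdm suppression, autograd anomaly detection) with a single training_utils.init_training(args, logger) call returning (device, device_id, disable_tqdm). A condensed sketch of what the removed boilerplate did, for reference; the actual helper in birder may differ in details such as the exact seeding call and the local-primary check:

import sys

import torch


def init_training_sketch(args, logger) -> tuple[torch.device, int, bool]:
    # Device selection
    if args.cpu is True:
        device = torch.device("cpu")
        device_id = 0
    else:
        device = torch.device("cuda")
        device_id = torch.cuda.current_device()

    # Deterministic vs. benchmark mode
    if args.use_deterministic_algorithms is True:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    # Seeding (the scripts previously called set_random_seeds(args.seed))
    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Disable the progress bar when non-interactive or when stderr is not a terminal
    disable_tqdm = args.non_interactive is True or sys.stderr.isatty() is False

    # Enable or disable the autograd anomaly detection
    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

    logger.info(f"Training initialized on {device}")
    return (device, device_id, disable_tqdm)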