birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/scripts/benchmark.py
CHANGED
|
@@ -13,6 +13,7 @@ from birder.common import cli
|
|
|
13
13
|
from birder.conf import settings
|
|
14
14
|
from birder.model_registry import Task
|
|
15
15
|
from birder.model_registry import registry
|
|
16
|
+
from birder.net.base import DetectorBackbone
|
|
16
17
|
|
|
17
18
|
logger = logging.getLogger(__name__)
|
|
18
19
|
|
|
@@ -27,6 +28,23 @@ def prepare_model(net: torch.nn.Module) -> None:
|
|
|
27
28
|
param.requires_grad_(False)
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
def init_plain_model(
|
|
32
|
+
model_name: str, sample_shape: tuple[int, ...], device: torch.device, args: argparse.Namespace
|
|
33
|
+
) -> torch.nn.Module:
|
|
34
|
+
size = (sample_shape[2], sample_shape[3])
|
|
35
|
+
input_channels = sample_shape[1]
|
|
36
|
+
if args.backbone is not None:
|
|
37
|
+
backbone = registry.net_factory(args.backbone, args.num_classes, input_channels, size=size)
|
|
38
|
+
net = registry.detection_net_factory(model_name, args.num_classes, backbone, size=size)
|
|
39
|
+
else:
|
|
40
|
+
net = registry.net_factory(model_name, args.num_classes, input_channels, size=size)
|
|
41
|
+
|
|
42
|
+
net.to(device)
|
|
43
|
+
prepare_model(net)
|
|
44
|
+
|
|
45
|
+
return net
|
|
46
|
+
|
|
47
|
+
|
|
30
48
|
def throughput_benchmark(
|
|
31
49
|
net: torch.nn.Module, device: torch.device, sample_shape: tuple[int, ...], model_name: str, args: argparse.Namespace
|
|
32
50
|
) -> tuple[float, int]:
|
|
@@ -110,14 +128,10 @@ def memory_benchmark(
|
|
|
110
128
|
)
|
|
111
129
|
|
|
112
130
|
if args.plain is True:
|
|
113
|
-
|
|
114
|
-
input_channels = sample_shape[1]
|
|
115
|
-
net = registry.net_factory(model_name, input_channels, 0, size=size)
|
|
116
|
-
net.to(device)
|
|
117
|
-
prepare_model(net)
|
|
131
|
+
net = init_plain_model(model_name, sample_shape, device, args)
|
|
118
132
|
|
|
119
133
|
else:
|
|
120
|
-
|
|
134
|
+
net, _ = birder.load_pretrained_model(model_name, inference=True, device=device)
|
|
121
135
|
if args.size is not None:
|
|
122
136
|
size = (sample_shape[2], sample_shape[3])
|
|
123
137
|
net.adjust_size(size)
|
|
@@ -182,7 +196,8 @@ def benchmark(args: argparse.Namespace) -> None:
|
|
|
182
196
|
if args.plain is True:
|
|
183
197
|
model_list = args.models or []
|
|
184
198
|
if len(model_list) == 0:
|
|
185
|
-
|
|
199
|
+
task = Task.OBJECT_DETECTION if args.backbone is not None else Task.IMAGE_CLASSIFICATION
|
|
200
|
+
model_list = registry.list_models(include_filter=args.filter, task=task)
|
|
186
201
|
|
|
187
202
|
else:
|
|
188
203
|
model_list = birder.list_pretrained_models(args.filter)
|
|
@@ -234,11 +249,9 @@ def benchmark(args: argparse.Namespace) -> None:
|
|
|
234
249
|
else:
|
|
235
250
|
# Initialize model
|
|
236
251
|
if args.plain is True:
|
|
237
|
-
net =
|
|
238
|
-
net.to(device)
|
|
239
|
-
prepare_model(net)
|
|
252
|
+
net = init_plain_model(model_name, sample_shape, device, args)
|
|
240
253
|
else:
|
|
241
|
-
|
|
254
|
+
net, _ = birder.load_pretrained_model(model_name, inference=True, device=device)
|
|
242
255
|
if args.size is not None:
|
|
243
256
|
net.adjust_size(size)
|
|
244
257
|
|
|
@@ -247,7 +260,7 @@ def benchmark(args: argparse.Namespace) -> None:
|
|
|
247
260
|
net = torch.compile(net)
|
|
248
261
|
|
|
249
262
|
peak_memory = None
|
|
250
|
-
|
|
263
|
+
t_elapsed, batch_size = throughput_benchmark(net, device, sample_shape, model_name, args)
|
|
251
264
|
if t_elapsed < 0.0:
|
|
252
265
|
continue
|
|
253
266
|
|
|
@@ -305,12 +318,18 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
305
318
|
"--compile --suffix il-common --append\n"
|
|
306
319
|
"python -m birder.scripts.benchmark --plain --models rdnet_t convnext_v1_tiny --bench-iter 50 --repeats 1 "
|
|
307
320
|
"--gpu --size 416 --dry-run\n"
|
|
321
|
+
"python -m birder.scripts.benchmark --plain --models retinanet --backbone resnet_v1_50 --num-classes 91 "
|
|
322
|
+
"--size 640 --gpu --dry-run\n"
|
|
308
323
|
),
|
|
309
324
|
formatter_class=cli.ArgumentHelpFormatter,
|
|
310
325
|
)
|
|
311
326
|
parser.add_argument("--filter", type=str, help="models to benchmark (fnmatch type filter)")
|
|
312
327
|
parser.add_argument("--models", nargs="+", help="plain network names to benchmark")
|
|
313
328
|
parser.add_argument("--plain", default=False, action="store_true", help="benchmark plain networks without weights")
|
|
329
|
+
parser.add_argument("--backbone", type=str, help="backbone name for plain detection benchmarks")
|
|
330
|
+
parser.add_argument(
|
|
331
|
+
"--num-classes", type=int, default=0, metavar="N", help="number of classes for plain benchmarks"
|
|
332
|
+
)
|
|
314
333
|
parser.add_argument("--compile", default=False, action="store_true", help="enable compilation")
|
|
315
334
|
parser.add_argument(
|
|
316
335
|
"--amp", default=False, action="store_true", help="use torch.amp.autocast for mixed precision inference"
|
|
@@ -353,6 +372,12 @@ def validate_args(args: argparse.Namespace) -> None:
|
|
|
353
372
|
raise cli.ValidationError("--memory cannot be used with --compile")
|
|
354
373
|
if args.plain is False and args.models is not None:
|
|
355
374
|
raise cli.ValidationError("--models can only be used with --plain")
|
|
375
|
+
if args.backbone is not None and args.plain is False:
|
|
376
|
+
raise cli.ValidationError("--backbone can only be used with --plain")
|
|
377
|
+
if args.backbone is not None and registry.exists(args.backbone, net_type=DetectorBackbone) is False:
|
|
378
|
+
raise cli.ValidationError(
|
|
379
|
+
f"--backbone {args.backbone} not supported, see list-models tool for available options"
|
|
380
|
+
)
|
|
356
381
|
|
|
357
382
|
|
|
358
383
|
def args_from_dict(**kwargs: Any) -> argparse.Namespace:
|
birder/scripts/evaluate.py
CHANGED
|
@@ -37,7 +37,7 @@ def evaluate(args: argparse.Namespace) -> None:
|
|
|
37
37
|
amp_dtype: torch.dtype = getattr(torch, args.amp_dtype)
|
|
38
38
|
model_list = birder.list_pretrained_models(args.filter)
|
|
39
39
|
for model_name in model_list:
|
|
40
|
-
|
|
40
|
+
net, (class_to_idx, signature, rgb_stats, *_) = birder.load_pretrained_model(
|
|
41
41
|
model_name, inference=True, device=device, dtype=model_dtype
|
|
42
42
|
)
|
|
43
43
|
if args.parallel is True and torch.cuda.device_count() > 1:
|
birder/scripts/predict.py
CHANGED
|
@@ -204,7 +204,7 @@ def predict(args: argparse.Namespace) -> None:
|
|
|
204
204
|
raise RuntimeError("'pip install torchao' to load quantization operators") from exc
|
|
205
205
|
|
|
206
206
|
network_name = lib.get_network_name(args.network, tag=args.tag)
|
|
207
|
-
|
|
207
|
+
net, (class_to_idx, signature, rgb_stats, *_) = fs_ops.load_model(
|
|
208
208
|
device,
|
|
209
209
|
args.network,
|
|
210
210
|
config=args.model_config,
|
|
@@ -261,11 +261,11 @@ def predict(args: argparse.Namespace) -> None:
|
|
|
261
261
|
if args.wds is True:
|
|
262
262
|
wds_path: str | list[str]
|
|
263
263
|
if args.wds_info is not None:
|
|
264
|
-
|
|
264
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
265
265
|
if args.wds_size is not None:
|
|
266
266
|
dataset_size = args.wds_size
|
|
267
267
|
else:
|
|
268
|
-
|
|
268
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
269
269
|
|
|
270
270
|
num_samples = dataset_size
|
|
271
271
|
dataset = make_wds_dataset(
|
|
@@ -60,7 +60,7 @@ def predict(args: argparse.Namespace) -> None:
|
|
|
60
60
|
network_name = lib.get_detection_network_name(
|
|
61
61
|
args.network, tag=args.tag, backbone=args.backbone, backbone_tag=args.backbone_tag
|
|
62
62
|
)
|
|
63
|
-
|
|
63
|
+
net, (class_to_idx, signature, rgb_stats, *_) = fs_ops.load_detection_model(
|
|
64
64
|
device,
|
|
65
65
|
args.network,
|
|
66
66
|
config=args.model_config,
|
|
@@ -197,7 +197,7 @@ def predict(args: argparse.Namespace) -> None:
|
|
|
197
197
|
# Inference
|
|
198
198
|
tic = time.time()
|
|
199
199
|
with torch.inference_mode():
|
|
200
|
-
|
|
200
|
+
sample_paths, detections, targets = infer_dataloader(
|
|
201
201
|
device,
|
|
202
202
|
net,
|
|
203
203
|
inference_loader,
|
birder/scripts/train.py
CHANGED
|
@@ -7,6 +7,7 @@ import time
|
|
|
7
7
|
from collections.abc import Iterator
|
|
8
8
|
from pathlib import Path
|
|
9
9
|
from typing import Any
|
|
10
|
+
from typing import Optional
|
|
10
11
|
|
|
11
12
|
import matplotlib.pyplot as plt
|
|
12
13
|
import numpy as np
|
|
@@ -52,7 +53,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
52
53
|
#
|
|
53
54
|
# Initialize
|
|
54
55
|
#
|
|
55
|
-
|
|
56
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
56
57
|
|
|
57
58
|
if args.size is None:
|
|
58
59
|
args.size = registry.get_default_size(args.network)
|
|
@@ -77,15 +78,15 @@ def train(args: argparse.Namespace) -> None:
|
|
|
77
78
|
training_wds_path: str | list[str]
|
|
78
79
|
val_wds_path: str | list[str]
|
|
79
80
|
if args.wds_info is not None:
|
|
80
|
-
|
|
81
|
-
|
|
81
|
+
training_wds_path, training_size = wds_args_from_info(args.wds_info, args.wds_training_split)
|
|
82
|
+
val_wds_path, val_size = wds_args_from_info(args.wds_info, args.wds_val_split)
|
|
82
83
|
if args.wds_train_size is not None:
|
|
83
84
|
training_size = args.wds_train_size
|
|
84
85
|
if args.wds_val_size is not None:
|
|
85
86
|
val_size = args.wds_val_size
|
|
86
87
|
else:
|
|
87
|
-
|
|
88
|
-
|
|
88
|
+
training_wds_path, training_size = prepare_wds_args(args.data_path, args.wds_train_size, device)
|
|
89
|
+
val_wds_path, val_size = prepare_wds_args(args.val_path, args.wds_val_size, device)
|
|
89
90
|
|
|
90
91
|
training_dataset = make_wds_dataset(
|
|
91
92
|
training_wds_path,
|
|
@@ -149,7 +150,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
149
150
|
|
|
150
151
|
# Data loaders and samplers
|
|
151
152
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
152
|
-
|
|
153
|
+
train_sampler, validation_sampler = training_utils.get_samplers(
|
|
153
154
|
args, training_dataset, validation_dataset, infinite=virtual_epoch_mode
|
|
154
155
|
)
|
|
155
156
|
|
|
@@ -231,7 +232,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
231
232
|
|
|
232
233
|
if args.resume_epoch is not None:
|
|
233
234
|
begin_epoch = args.resume_epoch + 1
|
|
234
|
-
|
|
235
|
+
net, class_to_idx_saved, training_states = fs_ops.load_checkpoint(
|
|
235
236
|
device,
|
|
236
237
|
args.network,
|
|
237
238
|
config=args.model_config,
|
|
@@ -247,7 +248,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
247
248
|
|
|
248
249
|
elif args.pretrained is True:
|
|
249
250
|
fs_ops.download_model_by_weights(network_name, progress_bar=training_utils.is_local_primary(args))
|
|
250
|
-
|
|
251
|
+
net, class_to_idx_saved, training_states = fs_ops.load_checkpoint(
|
|
251
252
|
device,
|
|
252
253
|
args.network,
|
|
253
254
|
config=args.model_config,
|
|
@@ -262,7 +263,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
262
263
|
assert class_to_idx == class_to_idx_saved
|
|
263
264
|
|
|
264
265
|
else:
|
|
265
|
-
net = registry.net_factory(args.network, sample_shape[1],
|
|
266
|
+
net = registry.net_factory(args.network, num_outputs, sample_shape[1], config=args.model_config, size=args.size)
|
|
266
267
|
training_states = fs_ops.TrainingStates.empty()
|
|
267
268
|
|
|
268
269
|
net.to(device, dtype=model_dtype)
|
|
@@ -328,7 +329,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
328
329
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
329
330
|
|
|
330
331
|
# Gradient scaler and AMP related tasks
|
|
331
|
-
|
|
332
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
332
333
|
|
|
333
334
|
# Load states
|
|
334
335
|
if args.load_states is True:
|
|
@@ -474,16 +475,32 @@ def train(args: argparse.Namespace) -> None:
|
|
|
474
475
|
if virtual_epoch_mode is True:
|
|
475
476
|
train_iter = iter(training_loader)
|
|
476
477
|
|
|
478
|
+
top_k = args.top_k
|
|
477
479
|
running_loss = training_utils.SmoothedValue(window_size=64)
|
|
478
480
|
running_val_loss = training_utils.SmoothedValue()
|
|
479
481
|
train_accuracy = training_utils.SmoothedValue(window_size=64)
|
|
480
482
|
val_accuracy = training_utils.SmoothedValue()
|
|
483
|
+
train_topk: Optional[training_utils.SmoothedValue] = None
|
|
484
|
+
val_topk: Optional[training_utils.SmoothedValue] = None
|
|
485
|
+
if top_k is not None:
|
|
486
|
+
train_topk = training_utils.SmoothedValue(window_size=64)
|
|
487
|
+
val_topk = training_utils.SmoothedValue()
|
|
481
488
|
|
|
482
489
|
logger.info(f"Starting training with learning rate of {last_lr}")
|
|
483
490
|
for epoch in range(begin_epoch, args.stop_epoch):
|
|
484
491
|
tic = time.time()
|
|
485
492
|
net.train()
|
|
486
493
|
|
|
494
|
+
# Clear metrics
|
|
495
|
+
running_loss.clear()
|
|
496
|
+
running_val_loss.clear()
|
|
497
|
+
train_accuracy.clear()
|
|
498
|
+
val_accuracy.clear()
|
|
499
|
+
if train_topk is not None:
|
|
500
|
+
train_topk.clear()
|
|
501
|
+
if val_topk is not None:
|
|
502
|
+
val_topk.clear()
|
|
503
|
+
|
|
487
504
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
488
505
|
train_sampler.set_epoch(epoch)
|
|
489
506
|
|
|
@@ -565,6 +582,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
565
582
|
targets = targets.argmax(dim=1)
|
|
566
583
|
|
|
567
584
|
train_accuracy.update(training_utils.accuracy(targets, outputs.detach()))
|
|
585
|
+
if train_topk is not None:
|
|
586
|
+
topk_val = training_utils.topk_accuracy(targets, outputs.detach(), topk=(top_k,))[0]
|
|
587
|
+
train_topk.update(topk_val)
|
|
568
588
|
|
|
569
589
|
# Write statistics
|
|
570
590
|
if (i % args.log_interval == 0 and i > 0) or i == last_batch_idx:
|
|
@@ -583,6 +603,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
583
603
|
|
|
584
604
|
running_loss.synchronize_between_processes(device)
|
|
585
605
|
train_accuracy.synchronize_between_processes(device)
|
|
606
|
+
if train_topk is not None:
|
|
607
|
+
train_topk.synchronize_between_processes(device)
|
|
608
|
+
|
|
586
609
|
with training_utils.single_handler_logging(logger, file_handler, enabled=not disable_tqdm) as log:
|
|
587
610
|
log.info(
|
|
588
611
|
f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
|
|
@@ -597,8 +620,17 @@ def train(args: argparse.Namespace) -> None:
|
|
|
597
620
|
f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
|
|
598
621
|
f"Accuracy: {train_accuracy.avg:.4f}"
|
|
599
622
|
)
|
|
623
|
+
if train_topk is not None:
|
|
624
|
+
log.info(
|
|
625
|
+
f"[Trn] Epoch {epoch}/{epochs-1}, iter {i+1}/{last_batch_idx+1} "
|
|
626
|
+
f"Accuracy@{top_k}: {train_topk.avg:.4f}"
|
|
627
|
+
)
|
|
600
628
|
|
|
601
629
|
if training_utils.is_local_primary(args) is True:
|
|
630
|
+
performance = {"training_accuracy": train_accuracy.avg}
|
|
631
|
+
if train_topk is not None:
|
|
632
|
+
performance[f"training_accuracy@{top_k}"] = train_topk.avg
|
|
633
|
+
|
|
602
634
|
summary_writer.add_scalars(
|
|
603
635
|
"loss",
|
|
604
636
|
{"training": running_loss.avg},
|
|
@@ -606,7 +638,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
606
638
|
)
|
|
607
639
|
summary_writer.add_scalars(
|
|
608
640
|
"performance",
|
|
609
|
-
|
|
641
|
+
performance,
|
|
610
642
|
((epoch - 1) * epoch_samples) + ((i + 1) * batch_size * args.world_size),
|
|
611
643
|
)
|
|
612
644
|
|
|
@@ -618,6 +650,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
618
650
|
# Epoch training metrics
|
|
619
651
|
logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_loss: {running_loss.global_avg:.4f}")
|
|
620
652
|
logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_accuracy: {train_accuracy.global_avg:.4f}")
|
|
653
|
+
if train_topk is not None:
|
|
654
|
+
logger.info(f"[Trn] Epoch {epoch}/{epochs-1} training_accuracy@{top_k}: {train_topk.global_avg:.4f}")
|
|
621
655
|
|
|
622
656
|
# Validation
|
|
623
657
|
eval_model.eval()
|
|
@@ -649,6 +683,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
649
683
|
# Statistics
|
|
650
684
|
running_val_loss.update(val_loss.detach())
|
|
651
685
|
val_accuracy.update(training_utils.accuracy(targets, outputs), n=outputs.size(0))
|
|
686
|
+
if val_topk is not None:
|
|
687
|
+
topk_val = training_utils.topk_accuracy(targets, outputs, topk=(top_k,))[0]
|
|
688
|
+
val_topk.update(topk_val, n=outputs.size(0))
|
|
652
689
|
|
|
653
690
|
# Update progress bar
|
|
654
691
|
progress.update(n=batch_size * args.world_size)
|
|
@@ -666,19 +703,30 @@ def train(args: argparse.Namespace) -> None:
|
|
|
666
703
|
|
|
667
704
|
running_val_loss.synchronize_between_processes(device)
|
|
668
705
|
val_accuracy.synchronize_between_processes(device)
|
|
706
|
+
if val_topk is not None:
|
|
707
|
+
val_topk.synchronize_between_processes(device)
|
|
708
|
+
|
|
669
709
|
epoch_val_loss = running_val_loss.global_avg
|
|
670
710
|
epoch_val_accuracy = val_accuracy.global_avg
|
|
711
|
+
if val_topk is not None:
|
|
712
|
+
epoch_val_topk = val_topk.global_avg
|
|
713
|
+
else:
|
|
714
|
+
epoch_val_topk = None
|
|
671
715
|
|
|
672
716
|
# Write statistics
|
|
673
717
|
if training_utils.is_local_primary(args) is True:
|
|
674
718
|
summary_writer.add_scalars("loss", {"validation": epoch_val_loss}, epoch * epoch_samples)
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
719
|
+
performance = {"validation_accuracy": epoch_val_accuracy}
|
|
720
|
+
if epoch_val_topk is not None:
|
|
721
|
+
performance[f"validation_accuracy@{top_k}"] = epoch_val_topk
|
|
722
|
+
|
|
723
|
+
summary_writer.add_scalars("performance", performance, epoch * epoch_samples)
|
|
678
724
|
|
|
679
725
|
# Epoch validation metrics
|
|
680
726
|
logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_loss: {epoch_val_loss:.4f}")
|
|
681
727
|
logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_accuracy: {epoch_val_accuracy:.4f}")
|
|
728
|
+
if epoch_val_topk is not None:
|
|
729
|
+
logger.info(f"[Val] Epoch {epoch}/{epochs-1} validation_accuracy@{top_k}: {epoch_val_topk:.4f}")
|
|
682
730
|
|
|
683
731
|
# Learning rate scheduler update
|
|
684
732
|
if step_update is False:
|
|
@@ -849,7 +897,7 @@ def get_args_parser() -> argparse.ArgumentParser:
|
|
|
849
897
|
training_cli.add_compile_args(parser)
|
|
850
898
|
training_cli.add_checkpoint_args(parser, default_save_frequency=5, pretrained=True)
|
|
851
899
|
training_cli.add_distributed_args(parser)
|
|
852
|
-
training_cli.add_logging_and_debug_args(parser)
|
|
900
|
+
training_cli.add_logging_and_debug_args(parser, classification=True)
|
|
853
901
|
training_cli.add_training_data_args(parser)
|
|
854
902
|
|
|
855
903
|
return parser
|
|
@@ -69,7 +69,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
69
69
|
#
|
|
70
70
|
# Initialize
|
|
71
71
|
#
|
|
72
|
-
|
|
72
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
73
73
|
|
|
74
74
|
if args.size is None:
|
|
75
75
|
args.size = registry.get_default_size(args.network)
|
|
@@ -92,11 +92,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
92
92
|
elif args.wds is True:
|
|
93
93
|
wds_path: str | list[str]
|
|
94
94
|
if args.wds_info is not None:
|
|
95
|
-
|
|
95
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
96
96
|
if args.wds_size is not None:
|
|
97
97
|
dataset_size = args.wds_size
|
|
98
98
|
else:
|
|
99
|
-
|
|
99
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
100
100
|
|
|
101
101
|
training_dataset = make_wds_dataset(
|
|
102
102
|
wds_path,
|
|
@@ -126,7 +126,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
126
126
|
|
|
127
127
|
# Data loaders and samplers
|
|
128
128
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
129
|
-
|
|
129
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
130
130
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
131
131
|
)
|
|
132
132
|
|
|
@@ -189,12 +189,12 @@ def train(args: argparse.Namespace) -> None:
|
|
|
189
189
|
|
|
190
190
|
network_name = get_mim_network_name("barlow_twins", encoder=args.network, tag=args.tag)
|
|
191
191
|
|
|
192
|
-
backbone = registry.net_factory(args.network, sample_shape[1],
|
|
192
|
+
backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
193
193
|
net = BarlowTwins(backbone, config={"projector_sizes": args.projector_dims, "off_lambda": args.off_lambda})
|
|
194
194
|
|
|
195
195
|
if args.resume_epoch is not None:
|
|
196
196
|
begin_epoch = args.resume_epoch + 1
|
|
197
|
-
|
|
197
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
198
198
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
199
199
|
)
|
|
200
200
|
|
|
@@ -253,7 +253,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
253
253
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
254
254
|
|
|
255
255
|
# Gradient scaler and AMP related tasks
|
|
256
|
-
|
|
256
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
257
257
|
|
|
258
258
|
# Load states
|
|
259
259
|
if args.load_states is True:
|
|
@@ -365,6 +365,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
365
365
|
tic = time.time()
|
|
366
366
|
net.train()
|
|
367
367
|
|
|
368
|
+
# Clear metrics
|
|
369
|
+
running_loss.clear()
|
|
370
|
+
|
|
368
371
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
369
372
|
train_sampler.set_epoch(epoch)
|
|
370
373
|
|
birder/scripts/train_byol.py
CHANGED
|
@@ -70,7 +70,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
70
70
|
#
|
|
71
71
|
# Initialize
|
|
72
72
|
#
|
|
73
|
-
|
|
73
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
74
74
|
|
|
75
75
|
if args.size is None:
|
|
76
76
|
# Prefer mim size over encoder default size
|
|
@@ -94,11 +94,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
94
94
|
elif args.wds is True:
|
|
95
95
|
wds_path: str | list[str]
|
|
96
96
|
if args.wds_info is not None:
|
|
97
|
-
|
|
97
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
98
98
|
if args.wds_size is not None:
|
|
99
99
|
dataset_size = args.wds_size
|
|
100
100
|
else:
|
|
101
|
-
|
|
101
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
102
102
|
|
|
103
103
|
training_dataset = make_wds_dataset(
|
|
104
104
|
wds_path,
|
|
@@ -128,7 +128,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
128
128
|
|
|
129
129
|
# Data loaders and samplers
|
|
130
130
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
131
|
-
|
|
131
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
132
132
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
133
133
|
)
|
|
134
134
|
|
|
@@ -191,7 +191,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
191
191
|
|
|
192
192
|
network_name = get_mim_network_name("byol", encoder=args.network, tag=args.tag)
|
|
193
193
|
|
|
194
|
-
backbone = registry.net_factory(args.network, sample_shape[1],
|
|
194
|
+
backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
195
195
|
net = BYOL(
|
|
196
196
|
backbone,
|
|
197
197
|
config={
|
|
@@ -202,7 +202,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
202
202
|
|
|
203
203
|
if args.resume_epoch is not None:
|
|
204
204
|
begin_epoch = args.resume_epoch + 1
|
|
205
|
-
|
|
205
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
206
206
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
207
207
|
)
|
|
208
208
|
|
|
@@ -265,7 +265,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
265
265
|
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
|
|
266
266
|
|
|
267
267
|
# Gradient scaler and AMP related tasks
|
|
268
|
-
|
|
268
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
269
269
|
|
|
270
270
|
# Load states
|
|
271
271
|
if args.load_states is True:
|
|
@@ -377,6 +377,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
377
377
|
tic = time.time()
|
|
378
378
|
net.train()
|
|
379
379
|
|
|
380
|
+
# Clear metrics
|
|
381
|
+
running_loss.clear()
|
|
382
|
+
|
|
380
383
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
381
384
|
train_sampler.set_epoch(epoch)
|
|
382
385
|
|
birder/scripts/train_capi.py
CHANGED
|
@@ -79,7 +79,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
79
79
|
#
|
|
80
80
|
# Initialize
|
|
81
81
|
#
|
|
82
|
-
|
|
82
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
83
83
|
|
|
84
84
|
if args.size is None:
|
|
85
85
|
args.size = registry.get_default_size(args.network)
|
|
@@ -108,8 +108,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
108
108
|
|
|
109
109
|
network_name = get_mim_network_name("capi", encoder=args.network, tag=args.tag)
|
|
110
110
|
|
|
111
|
-
student_backbone = registry.net_factory(args.network, sample_shape[1],
|
|
112
|
-
teacher_backbone = registry.net_factory(args.network, sample_shape[1],
|
|
111
|
+
student_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
112
|
+
teacher_backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
113
113
|
|
|
114
114
|
teacher_backbone.load_state_dict(student_backbone.state_dict())
|
|
115
115
|
|
|
@@ -144,7 +144,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
144
144
|
|
|
145
145
|
if args.resume_epoch is not None:
|
|
146
146
|
begin_epoch = args.resume_epoch + 1
|
|
147
|
-
|
|
147
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
148
148
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
149
149
|
)
|
|
150
150
|
student = net["student"]
|
|
@@ -194,11 +194,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
194
194
|
elif args.wds is True:
|
|
195
195
|
wds_path: str | list[str]
|
|
196
196
|
if args.wds_info is not None:
|
|
197
|
-
|
|
197
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
198
198
|
if args.wds_size is not None:
|
|
199
199
|
dataset_size = args.wds_size
|
|
200
200
|
else:
|
|
201
|
-
|
|
201
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
202
202
|
|
|
203
203
|
training_dataset = make_wds_dataset(
|
|
204
204
|
wds_path,
|
|
@@ -224,7 +224,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
224
224
|
|
|
225
225
|
# Data loaders and samplers
|
|
226
226
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
227
|
-
|
|
227
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
228
228
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
229
229
|
)
|
|
230
230
|
|
|
@@ -326,8 +326,8 @@ def train(args: argparse.Namespace) -> None:
|
|
|
326
326
|
student_temp = 0.12
|
|
327
327
|
|
|
328
328
|
# Gradient scaler and AMP related tasks
|
|
329
|
-
|
|
330
|
-
|
|
329
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
330
|
+
clustering_scaler, _ = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
331
331
|
|
|
332
332
|
# Load states
|
|
333
333
|
if args.load_states is True:
|
|
@@ -453,6 +453,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
453
453
|
tic = time.time()
|
|
454
454
|
net.train()
|
|
455
455
|
|
|
456
|
+
# Clear metrics
|
|
457
|
+
running_loss.clear()
|
|
458
|
+
running_clustering_loss.clear()
|
|
459
|
+
running_target_entropy.clear()
|
|
460
|
+
|
|
456
461
|
if args.sinkhorn_queue_size is not None:
|
|
457
462
|
queue_active = epoch > args.sinkhorn_queue_warmup_epochs
|
|
458
463
|
teacher_without_ddp.head.set_queue_active(queue_active)
|
|
@@ -499,7 +504,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
499
504
|
|
|
500
505
|
# Forward, backward and optimize
|
|
501
506
|
with torch.amp.autocast("cuda", enabled=args.amp, dtype=amp_dtype):
|
|
502
|
-
|
|
507
|
+
selected_assignments, clustering_loss = teacher(images, None, predict_indices)
|
|
503
508
|
|
|
504
509
|
if clustering_scaler is not None:
|
|
505
510
|
clustering_scaler.scale(clustering_loss).backward()
|
birder/scripts/train_data2vec.py
CHANGED
|
@@ -69,7 +69,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
69
69
|
#
|
|
70
70
|
# Initialize
|
|
71
71
|
#
|
|
72
|
-
|
|
72
|
+
device, device_id, disable_tqdm = training_utils.init_training(args, logger)
|
|
73
73
|
|
|
74
74
|
if args.size is None:
|
|
75
75
|
# Prefer mim size over encoder default size
|
|
@@ -99,7 +99,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
99
99
|
|
|
100
100
|
network_name = get_mim_network_name("data2vec", encoder=args.network, tag=args.tag)
|
|
101
101
|
|
|
102
|
-
backbone = registry.net_factory(args.network, sample_shape[1],
|
|
102
|
+
backbone = registry.net_factory(args.network, 0, sample_shape[1], config=args.model_config, size=args.size)
|
|
103
103
|
net = Data2Vec(
|
|
104
104
|
backbone,
|
|
105
105
|
config={
|
|
@@ -112,7 +112,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
112
112
|
|
|
113
113
|
if args.resume_epoch is not None:
|
|
114
114
|
begin_epoch = args.resume_epoch + 1
|
|
115
|
-
|
|
115
|
+
net, training_states = fs_ops.load_simple_checkpoint(
|
|
116
116
|
device, net, network_name, epoch=args.resume_epoch, strict=not args.non_strict_weights
|
|
117
117
|
)
|
|
118
118
|
|
|
@@ -160,11 +160,11 @@ def train(args: argparse.Namespace) -> None:
|
|
|
160
160
|
elif args.wds is True:
|
|
161
161
|
wds_path: str | list[str]
|
|
162
162
|
if args.wds_info is not None:
|
|
163
|
-
|
|
163
|
+
wds_path, dataset_size = wds_args_from_info(args.wds_info, args.wds_split)
|
|
164
164
|
if args.wds_size is not None:
|
|
165
165
|
dataset_size = args.wds_size
|
|
166
166
|
else:
|
|
167
|
-
|
|
167
|
+
wds_path, dataset_size = prepare_wds_args(args.data_path[0], args.wds_size, device)
|
|
168
168
|
|
|
169
169
|
training_dataset = make_wds_dataset(
|
|
170
170
|
wds_path,
|
|
@@ -190,7 +190,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
190
190
|
|
|
191
191
|
# Data loaders and samplers
|
|
192
192
|
virtual_epoch_mode = args.steps_per_epoch is not None
|
|
193
|
-
|
|
193
|
+
train_sampler, _ = training_utils.get_samplers(
|
|
194
194
|
args, training_dataset, validation_dataset=None, infinite=virtual_epoch_mode
|
|
195
195
|
)
|
|
196
196
|
|
|
@@ -279,7 +279,7 @@ def train(args: argparse.Namespace) -> None:
|
|
|
279
279
|
)
|
|
280
280
|
|
|
281
281
|
# Gradient scaler and AMP related tasks
|
|
282
|
-
|
|
282
|
+
scaler, amp_dtype = training_utils.get_amp_scaler(args.amp, args.amp_dtype)
|
|
283
283
|
|
|
284
284
|
# Load states
|
|
285
285
|
if args.load_states is True:
|
|
@@ -391,6 +391,9 @@ def train(args: argparse.Namespace) -> None:
|
|
|
391
391
|
tic = time.time()
|
|
392
392
|
net.train()
|
|
393
393
|
|
|
394
|
+
# Clear metrics
|
|
395
|
+
running_loss.clear()
|
|
396
|
+
|
|
394
397
|
if args.distributed is True or virtual_epoch_mode is True:
|
|
395
398
|
train_sampler.set_epoch(epoch)
|
|
396
399
|
|