ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
@@ -1,17 +1,18 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
"""
|
3
3
|
Train a model on a dataset.
|
4
4
|
|
5
5
|
Usage:
|
6
|
-
$ yolo mode=train model=yolov8n.pt data=
|
6
|
+
$ yolo mode=train model=yolov8n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
|
7
7
|
"""
|
8
8
|
|
9
|
+
import gc
|
9
10
|
import math
|
10
11
|
import os
|
11
12
|
import subprocess
|
12
13
|
import time
|
13
14
|
import warnings
|
14
|
-
from copy import deepcopy
|
15
|
+
from copy import copy, deepcopy
|
15
16
|
from datetime import datetime, timedelta
|
16
17
|
from pathlib import Path
|
17
18
|
|
@@ -25,6 +26,7 @@ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
|
25
26
|
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
26
27
|
from ultralytics.utils import (
|
27
28
|
DEFAULT_CFG,
|
29
|
+
LOCAL_RANK,
|
28
30
|
LOGGER,
|
29
31
|
RANK,
|
30
32
|
TQDM,
|
@@ -40,20 +42,21 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
|
|
40
42
|
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
41
43
|
from ultralytics.utils.files import get_latest_run
|
42
44
|
from ultralytics.utils.torch_utils import (
|
45
|
+
TORCH_2_4,
|
43
46
|
EarlyStopping,
|
44
47
|
ModelEMA,
|
45
|
-
|
48
|
+
autocast,
|
49
|
+
convert_optimizer_state_dict_to_fp16,
|
46
50
|
init_seeds,
|
47
51
|
one_cycle,
|
48
52
|
select_device,
|
49
53
|
strip_optimizer,
|
54
|
+
torch_distributed_zero_first,
|
50
55
|
)
|
51
56
|
|
52
57
|
|
53
58
|
class BaseTrainer:
|
54
59
|
"""
|
55
|
-
BaseTrainer.
|
56
|
-
|
57
60
|
A base class for creating trainers.
|
58
61
|
|
59
62
|
Attributes:
|
@@ -107,7 +110,7 @@ class BaseTrainer:
|
|
107
110
|
self.save_dir = get_save_dir(self.args)
|
108
111
|
self.args.name = self.save_dir.name # update name for loggers
|
109
112
|
self.wdir = self.save_dir / "weights" # weights dir
|
110
|
-
if RANK in
|
113
|
+
if RANK in {-1, 0}:
|
111
114
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
112
115
|
self.args.save_dir = str(self.save_dir)
|
113
116
|
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
@@ -115,33 +118,19 @@ class BaseTrainer:
|
|
115
118
|
self.save_period = self.args.save_period
|
116
119
|
|
117
120
|
self.batch_size = self.args.batch
|
118
|
-
self.epochs = self.args.epochs
|
121
|
+
self.epochs = self.args.epochs or 100 # in case users accidentally pass epochs=None with timed training
|
119
122
|
self.start_epoch = 0
|
120
123
|
if RANK == -1:
|
121
124
|
print_args(vars(self.args))
|
122
125
|
|
123
126
|
# Device
|
124
|
-
if self.device.type in
|
127
|
+
if self.device.type in {"cpu", "mps"}:
|
125
128
|
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
126
129
|
|
127
130
|
# Model and Dataset
|
128
131
|
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
129
|
-
|
130
|
-
|
131
|
-
self.data = check_cls_dataset(self.args.data)
|
132
|
-
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
133
|
-
"detect",
|
134
|
-
"segment",
|
135
|
-
"pose",
|
136
|
-
"obb",
|
137
|
-
):
|
138
|
-
self.data = check_det_dataset(self.args.data)
|
139
|
-
if "yaml_file" in self.data:
|
140
|
-
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
141
|
-
except Exception as e:
|
142
|
-
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
143
|
-
|
144
|
-
self.trainset, self.testset = self.get_dataset(self.data)
|
132
|
+
with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times
|
133
|
+
self.trainset, self.testset = self.get_dataset()
|
145
134
|
self.ema = None
|
146
135
|
|
147
136
|
# Optimization utils init
|
@@ -157,9 +146,12 @@ class BaseTrainer:
|
|
157
146
|
self.csv = self.save_dir / "results.csv"
|
158
147
|
self.plot_idx = [0, 1, 2]
|
159
148
|
|
149
|
+
# HUB
|
150
|
+
self.hub_session = None
|
151
|
+
|
160
152
|
# Callbacks
|
161
153
|
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
162
|
-
if RANK in
|
154
|
+
if RANK in {-1, 0}:
|
163
155
|
callbacks.add_integration_callbacks(self)
|
164
156
|
|
165
157
|
def add_callback(self, event: str, callback):
|
@@ -181,9 +173,11 @@ class BaseTrainer:
|
|
181
173
|
world_size = len(self.args.device.split(","))
|
182
174
|
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
183
175
|
world_size = len(self.args.device)
|
176
|
+
elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps'
|
177
|
+
world_size = 0
|
184
178
|
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
185
179
|
world_size = 1 # default to device 0
|
186
|
-
else: # i.e. device=
|
180
|
+
else: # i.e. device=None or device=''
|
187
181
|
world_size = 0
|
188
182
|
|
189
183
|
# Run subprocess if DDP training, else train normally
|
@@ -192,9 +186,9 @@ class BaseTrainer:
|
|
192
186
|
if self.args.rect:
|
193
187
|
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
|
194
188
|
self.args.rect = False
|
195
|
-
if self.args.batch
|
189
|
+
if self.args.batch < 1.0:
|
196
190
|
LOGGER.warning(
|
197
|
-
"WARNING ⚠️ 'batch
|
191
|
+
"WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting "
|
198
192
|
"default 'batch=16'"
|
199
193
|
)
|
200
194
|
self.args.batch = 16
|
@@ -202,7 +196,7 @@ class BaseTrainer:
|
|
202
196
|
# Command
|
203
197
|
cmd, file = generate_ddp_command(world_size, self)
|
204
198
|
try:
|
205
|
-
LOGGER.info(f
|
199
|
+
LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
|
206
200
|
subprocess.run(cmd, check=True)
|
207
201
|
except Exception as e:
|
208
202
|
raise e
|
@@ -225,9 +219,9 @@ class BaseTrainer:
|
|
225
219
|
torch.cuda.set_device(RANK)
|
226
220
|
self.device = torch.device("cuda", RANK)
|
227
221
|
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
228
|
-
os.environ["
|
222
|
+
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
|
229
223
|
dist.init_process_group(
|
230
|
-
"nccl" if dist.is_nccl_available() else "gloo",
|
224
|
+
backend="nccl" if dist.is_nccl_available() else "gloo",
|
231
225
|
timeout=timedelta(seconds=10800), # 3 hours
|
232
226
|
rank=RANK,
|
233
227
|
world_size=world_size,
|
@@ -235,7 +229,6 @@ class BaseTrainer:
|
|
235
229
|
|
236
230
|
def _setup_train(self, world_size):
|
237
231
|
"""Builds dataloaders and optimizer on correct rank process."""
|
238
|
-
|
239
232
|
# Model
|
240
233
|
self.run_callbacks("on_pretrain_routine_start")
|
241
234
|
ckpt = self.setup_model()
|
@@ -266,16 +259,18 @@ class BaseTrainer:
|
|
266
259
|
|
267
260
|
# Check AMP
|
268
261
|
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
269
|
-
if self.amp and RANK in
|
262
|
+
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
|
270
263
|
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
271
264
|
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
272
265
|
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
273
266
|
if RANK > -1 and world_size > 1: # DDP
|
274
267
|
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
275
268
|
self.amp = bool(self.amp) # as boolean
|
276
|
-
self.scaler =
|
269
|
+
self.scaler = (
|
270
|
+
torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
|
271
|
+
)
|
277
272
|
if world_size > 1:
|
278
|
-
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
|
273
|
+
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
|
279
274
|
|
280
275
|
# Check imgsz
|
281
276
|
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
|
@@ -283,13 +278,13 @@ class BaseTrainer:
|
|
283
278
|
self.stride = gs # for multiscale training
|
284
279
|
|
285
280
|
# Batch size
|
286
|
-
if self.batch_size
|
287
|
-
self.args.batch = self.batch_size =
|
281
|
+
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
|
282
|
+
self.args.batch = self.batch_size = self.auto_batch()
|
288
283
|
|
289
284
|
# Dataloaders
|
290
285
|
batch_size = self.batch_size // max(world_size, 1)
|
291
|
-
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=
|
292
|
-
if RANK in
|
286
|
+
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
|
287
|
+
if RANK in {-1, 0}:
|
293
288
|
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
294
289
|
self.test_loader = self.get_dataloader(
|
295
290
|
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
@@ -334,18 +329,23 @@ class BaseTrainer:
|
|
334
329
|
self.train_time_start = time.time()
|
335
330
|
self.run_callbacks("on_train_start")
|
336
331
|
LOGGER.info(
|
337
|
-
f
|
338
|
-
f
|
332
|
+
f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
333
|
+
f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
|
339
334
|
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
340
|
-
f
|
335
|
+
f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
|
341
336
|
)
|
342
337
|
if self.args.close_mosaic:
|
343
338
|
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
344
339
|
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
345
340
|
epoch = self.start_epoch
|
341
|
+
self.optimizer.zero_grad() # zero any resumed gradients to ensure stability on train start
|
346
342
|
while True:
|
347
343
|
self.epoch = epoch
|
348
344
|
self.run_callbacks("on_train_epoch_start")
|
345
|
+
with warnings.catch_warnings():
|
346
|
+
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
347
|
+
self.scheduler.step()
|
348
|
+
|
349
349
|
self.model.train()
|
350
350
|
if RANK != -1:
|
351
351
|
self.train_loader.sampler.set_epoch(epoch)
|
@@ -355,11 +355,10 @@ class BaseTrainer:
|
|
355
355
|
self._close_dataloader_mosaic()
|
356
356
|
self.train_loader.reset()
|
357
357
|
|
358
|
-
if RANK in
|
358
|
+
if RANK in {-1, 0}:
|
359
359
|
LOGGER.info(self.progress_string())
|
360
360
|
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
361
361
|
self.tloss = None
|
362
|
-
self.optimizer.zero_grad()
|
363
362
|
for i, batch in pbar:
|
364
363
|
self.run_callbacks("on_train_batch_start")
|
365
364
|
# Warmup
|
@@ -376,7 +375,7 @@ class BaseTrainer:
|
|
376
375
|
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
377
376
|
|
378
377
|
# Forward
|
379
|
-
with
|
378
|
+
with autocast(self.amp):
|
380
379
|
batch = self.preprocess_batch(batch)
|
381
380
|
self.loss, self.loss_items = self.model(batch)
|
382
381
|
if RANK != -1:
|
@@ -404,13 +403,17 @@ class BaseTrainer:
|
|
404
403
|
break
|
405
404
|
|
406
405
|
# Log
|
407
|
-
|
408
|
-
|
409
|
-
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
410
|
-
if RANK in (-1, 0):
|
406
|
+
if RANK in {-1, 0}:
|
407
|
+
loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
411
408
|
pbar.set_description(
|
412
|
-
("%11s" * 2 + "%11.4g" * (2 +
|
413
|
-
% (
|
409
|
+
("%11s" * 2 + "%11.4g" * (2 + loss_length))
|
410
|
+
% (
|
411
|
+
f"{epoch + 1}/{self.epochs}",
|
412
|
+
f"{self._get_memory():.3g}G", # (GB) GPU memory util
|
413
|
+
*(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)), # losses
|
414
|
+
batch["cls"].shape[0], # batch size, i.e. 8
|
415
|
+
batch["img"].shape[-1], # imgsz, i.e 640
|
416
|
+
)
|
414
417
|
)
|
415
418
|
self.run_callbacks("on_batch_end")
|
416
419
|
if self.args.plots and ni in self.plot_idx:
|
@@ -420,8 +423,8 @@ class BaseTrainer:
|
|
420
423
|
|
421
424
|
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
422
425
|
self.run_callbacks("on_train_epoch_end")
|
423
|
-
if RANK in
|
424
|
-
final_epoch = epoch + 1
|
426
|
+
if RANK in {-1, 0}:
|
427
|
+
final_epoch = epoch + 1 >= self.epochs
|
425
428
|
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
426
429
|
|
427
430
|
# Validation
|
@@ -441,17 +444,14 @@ class BaseTrainer:
|
|
441
444
|
t = time.time()
|
442
445
|
self.epoch_time = t - self.epoch_time_start
|
443
446
|
self.epoch_time_start = t
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
self.scheduler.last_epoch = self.epoch # do not move
|
451
|
-
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
452
|
-
self.scheduler.step()
|
447
|
+
if self.args.time:
|
448
|
+
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
449
|
+
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
450
|
+
self._setup_scheduler()
|
451
|
+
self.scheduler.last_epoch = self.epoch # do not move
|
452
|
+
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
453
453
|
self.run_callbacks("on_fit_epoch_end")
|
454
|
-
|
454
|
+
self._clear_memory()
|
455
455
|
|
456
456
|
# Early Stopping
|
457
457
|
if RANK != -1: # if DDP training
|
@@ -462,55 +462,109 @@ class BaseTrainer:
|
|
462
462
|
break # must break all DDP ranks
|
463
463
|
epoch += 1
|
464
464
|
|
465
|
-
if RANK in
|
465
|
+
if RANK in {-1, 0}:
|
466
466
|
# Do final val with best.pt
|
467
|
-
|
468
|
-
|
469
|
-
f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
|
470
|
-
)
|
467
|
+
seconds = time.time() - self.train_time_start
|
468
|
+
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
|
471
469
|
self.final_eval()
|
472
470
|
if self.args.plots:
|
473
471
|
self.plot_metrics()
|
474
472
|
self.run_callbacks("on_train_end")
|
475
|
-
|
473
|
+
self._clear_memory()
|
476
474
|
self.run_callbacks("teardown")
|
477
475
|
|
476
|
+
def auto_batch(self, max_num_obj=0):
|
477
|
+
"""Get batch size by calculating memory occupation of model."""
|
478
|
+
return check_train_batch_size(
|
479
|
+
model=self.model,
|
480
|
+
imgsz=self.args.imgsz,
|
481
|
+
amp=self.amp,
|
482
|
+
batch=self.batch_size,
|
483
|
+
max_num_obj=max_num_obj,
|
484
|
+
) # returns batch size
|
485
|
+
|
486
|
+
def _get_memory(self):
|
487
|
+
"""Get accelerator memory utilization in GB."""
|
488
|
+
if self.device.type == "mps":
|
489
|
+
memory = torch.mps.driver_allocated_memory()
|
490
|
+
elif self.device.type == "cpu":
|
491
|
+
memory = 0
|
492
|
+
else:
|
493
|
+
memory = torch.cuda.memory_reserved()
|
494
|
+
return memory / 1e9
|
495
|
+
|
496
|
+
def _clear_memory(self):
|
497
|
+
"""Clear accelerator memory on different platforms."""
|
498
|
+
gc.collect()
|
499
|
+
if self.device.type == "mps":
|
500
|
+
torch.mps.empty_cache()
|
501
|
+
elif self.device.type == "cpu":
|
502
|
+
return
|
503
|
+
else:
|
504
|
+
torch.cuda.empty_cache()
|
505
|
+
|
506
|
+
def read_results_csv(self):
|
507
|
+
"""Read results.csv into a dict using pandas."""
|
508
|
+
import pandas as pd # scope for faster 'import ultralytics'
|
509
|
+
|
510
|
+
return pd.read_csv(self.csv).to_dict(orient="list")
|
511
|
+
|
478
512
|
def save_model(self):
|
479
513
|
"""Save model training checkpoints with additional metadata."""
|
480
|
-
import
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
514
|
+
import io
|
515
|
+
|
516
|
+
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
|
517
|
+
buffer = io.BytesIO()
|
518
|
+
torch.save(
|
519
|
+
{
|
520
|
+
"epoch": self.epoch,
|
521
|
+
"best_fitness": self.best_fitness,
|
522
|
+
"model": None, # resume and final checkpoints derive from EMA
|
523
|
+
"ema": deepcopy(self.ema.ema).half(),
|
524
|
+
"updates": self.ema.updates,
|
525
|
+
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
|
526
|
+
"train_args": vars(self.args), # save as dict
|
527
|
+
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
|
528
|
+
"train_results": self.read_results_csv(),
|
529
|
+
"date": datetime.now().isoformat(),
|
530
|
+
"version": __version__,
|
531
|
+
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
532
|
+
"docs": "https://docs.ultralytics.com",
|
533
|
+
},
|
534
|
+
buffer,
|
535
|
+
)
|
536
|
+
serialized_ckpt = buffer.getvalue() # get the serialized content to save
|
537
|
+
|
538
|
+
# Save checkpoints
|
539
|
+
self.last.write_bytes(serialized_ckpt) # save last.pt
|
502
540
|
if self.best_fitness == self.fitness:
|
503
|
-
|
504
|
-
if (self.save_period > 0) and (self.epoch
|
505
|
-
|
541
|
+
self.best.write_bytes(serialized_ckpt) # save best.pt
|
542
|
+
if (self.save_period > 0) and (self.epoch % self.save_period == 0):
|
543
|
+
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
544
|
+
# if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
|
545
|
+
# (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt) # save mosaic checkpoint
|
506
546
|
|
507
|
-
|
508
|
-
def get_dataset(data):
|
547
|
+
def get_dataset(self):
|
509
548
|
"""
|
510
549
|
Get train, val path from data dict if it exists.
|
511
550
|
|
512
551
|
Returns None if data format is not recognized.
|
513
552
|
"""
|
553
|
+
try:
|
554
|
+
if self.args.task == "classify":
|
555
|
+
data = check_cls_dataset(self.args.data)
|
556
|
+
elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
|
557
|
+
"detect",
|
558
|
+
"segment",
|
559
|
+
"pose",
|
560
|
+
"obb",
|
561
|
+
}:
|
562
|
+
data = check_det_dataset(self.args.data)
|
563
|
+
if "yaml_file" in data:
|
564
|
+
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
565
|
+
except Exception as e:
|
566
|
+
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
567
|
+
self.data = data
|
514
568
|
return data["train"], data.get("val") or data.get("test")
|
515
569
|
|
516
570
|
def setup_model(self):
|
@@ -518,13 +572,13 @@ class BaseTrainer:
|
|
518
572
|
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
519
573
|
return
|
520
574
|
|
521
|
-
|
575
|
+
cfg, weights = self.model, None
|
522
576
|
ckpt = None
|
523
|
-
if str(model).endswith(".pt"):
|
524
|
-
weights, ckpt = attempt_load_one_weight(model)
|
525
|
-
cfg =
|
526
|
-
|
527
|
-
|
577
|
+
if str(self.model).endswith(".pt"):
|
578
|
+
weights, ckpt = attempt_load_one_weight(self.model)
|
579
|
+
cfg = weights.yaml
|
580
|
+
elif isinstance(self.args.pretrained, (str, Path)):
|
581
|
+
weights, _ = attempt_load_one_weight(self.args.pretrained)
|
528
582
|
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
529
583
|
return ckpt
|
530
584
|
|
@@ -603,26 +657,31 @@ class BaseTrainer:
|
|
603
657
|
def save_metrics(self, metrics):
|
604
658
|
"""Saves training metrics to a CSV file."""
|
605
659
|
keys, vals = list(metrics.keys()), list(metrics.values())
|
606
|
-
n = len(metrics) +
|
607
|
-
s = "" if self.csv.exists() else (("%
|
660
|
+
n = len(metrics) + 2 # number of cols
|
661
|
+
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
|
662
|
+
t = time.time() - self.train_time_start
|
608
663
|
with open(self.csv, "a") as f:
|
609
|
-
f.write(s + ("
|
664
|
+
f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
|
610
665
|
|
611
666
|
def plot_metrics(self):
|
612
667
|
"""Plot and display metrics visually."""
|
613
668
|
pass
|
614
669
|
|
615
670
|
def on_plot(self, name, data=None):
|
616
|
-
"""Registers plots (e.g. to be consumed in callbacks)"""
|
671
|
+
"""Registers plots (e.g. to be consumed in callbacks)."""
|
617
672
|
path = Path(name)
|
618
673
|
self.plots[path] = {"data": data, "timestamp": time.time()}
|
619
674
|
|
620
675
|
def final_eval(self):
|
621
676
|
"""Performs final evaluation and validation for object detection YOLO model."""
|
677
|
+
ckpt = {}
|
622
678
|
for f in self.last, self.best:
|
623
679
|
if f.exists():
|
624
|
-
|
625
|
-
|
680
|
+
if f is self.last:
|
681
|
+
ckpt = strip_optimizer(f)
|
682
|
+
elif f is self.best:
|
683
|
+
k = "train_results" # update best.pt train_metrics from last.pt
|
684
|
+
strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
|
626
685
|
LOGGER.info(f"\nValidating {f}...")
|
627
686
|
self.validator.args.plots = self.args.plots
|
628
687
|
self.metrics = self.validator(model=f)
|
@@ -644,8 +703,13 @@ class BaseTrainer:
|
|
644
703
|
|
645
704
|
resume = True
|
646
705
|
self.args = get_cfg(ckpt_args)
|
647
|
-
self.args.model = str(last) # reinstate model
|
648
|
-
for k in
|
706
|
+
self.args.model = self.args.resume = str(last) # reinstate model
|
707
|
+
for k in (
|
708
|
+
"imgsz",
|
709
|
+
"batch",
|
710
|
+
"device",
|
711
|
+
"close_mosaic",
|
712
|
+
): # allow arg updates to reduce memory or update device on resume
|
649
713
|
if k in overrides:
|
650
714
|
setattr(self.args, k, overrides[k])
|
651
715
|
|
@@ -658,24 +722,21 @@ class BaseTrainer:
|
|
658
722
|
|
659
723
|
def resume_training(self, ckpt):
|
660
724
|
"""Resume YOLO training from given epoch and best fitness."""
|
661
|
-
if ckpt is None:
|
725
|
+
if ckpt is None or not self.resume:
|
662
726
|
return
|
663
727
|
best_fitness = 0.0
|
664
|
-
start_epoch = ckpt
|
665
|
-
if ckpt
|
728
|
+
start_epoch = ckpt.get("epoch", -1) + 1
|
729
|
+
if ckpt.get("optimizer", None) is not None:
|
666
730
|
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
667
731
|
best_fitness = ckpt["best_fitness"]
|
668
732
|
if self.ema and ckpt.get("ema"):
|
669
733
|
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
|
670
734
|
self.ema.updates = ckpt["updates"]
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
LOGGER.info(
|
677
|
-
f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
|
678
|
-
)
|
735
|
+
assert start_epoch > 0, (
|
736
|
+
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
|
737
|
+
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
738
|
+
)
|
739
|
+
LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
|
679
740
|
if self.epochs < start_epoch:
|
680
741
|
LOGGER.info(
|
681
742
|
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
|
@@ -692,7 +753,7 @@ class BaseTrainer:
|
|
692
753
|
self.train_loader.dataset.mosaic = False
|
693
754
|
if hasattr(self.train_loader.dataset, "close_mosaic"):
|
694
755
|
LOGGER.info("Closing dataloader mosaic")
|
695
|
-
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
756
|
+
self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
|
696
757
|
|
697
758
|
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
698
759
|
"""
|
@@ -712,7 +773,6 @@ class BaseTrainer:
|
|
712
773
|
Returns:
|
713
774
|
(torch.optim.Optimizer): The constructed optimizer.
|
714
775
|
"""
|
715
|
-
|
716
776
|
g = [], [], [] # optimizer parameter groups
|
717
777
|
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
718
778
|
if name == "auto":
|
@@ -736,7 +796,9 @@ class BaseTrainer:
|
|
736
796
|
else: # weight (with decay)
|
737
797
|
g[0].append(param)
|
738
798
|
|
739
|
-
|
799
|
+
optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
|
800
|
+
name = {x.lower(): x for x in optimizers}.get(name.lower())
|
801
|
+
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
|
740
802
|
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
741
803
|
elif name == "RMSProp":
|
742
804
|
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
@@ -744,15 +806,14 @@ class BaseTrainer:
|
|
744
806
|
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
745
807
|
else:
|
746
808
|
raise NotImplementedError(
|
747
|
-
f"Optimizer '{name}' not found in list of available optimizers "
|
748
|
-
|
749
|
-
"To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
|
809
|
+
f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
|
810
|
+
"Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
|
750
811
|
)
|
751
812
|
|
752
813
|
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
|
753
814
|
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
|
754
815
|
LOGGER.info(
|
755
816
|
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
756
|
-
f
|
817
|
+
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
|
757
818
|
)
|
758
819
|
return optimizer
|