ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +37 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +111 -41
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +579 -244
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +191 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +226 -82
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +172 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +305 -112
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.63.dist-info/METADATA +370 -0
- ultralytics-8.3.63.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
@@ -1,17 +1,18 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
"""
|
3
3
|
Train a model on a dataset.
|
4
4
|
|
5
5
|
Usage:
|
6
|
-
$ yolo mode=train model=yolov8n.pt data=
|
6
|
+
$ yolo mode=train model=yolov8n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
|
7
7
|
"""
|
8
8
|
|
9
|
+
import gc
|
9
10
|
import math
|
10
11
|
import os
|
11
12
|
import subprocess
|
12
13
|
import time
|
13
14
|
import warnings
|
14
|
-
from copy import deepcopy
|
15
|
+
from copy import copy, deepcopy
|
15
16
|
from datetime import datetime, timedelta
|
16
17
|
from pathlib import Path
|
17
18
|
|
@@ -25,6 +26,7 @@ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
|
25
26
|
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
26
27
|
from ultralytics.utils import (
|
27
28
|
DEFAULT_CFG,
|
29
|
+
LOCAL_RANK,
|
28
30
|
LOGGER,
|
29
31
|
RANK,
|
30
32
|
TQDM,
|
@@ -40,20 +42,21 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
|
|
40
42
|
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
41
43
|
from ultralytics.utils.files import get_latest_run
|
42
44
|
from ultralytics.utils.torch_utils import (
|
45
|
+
TORCH_2_4,
|
43
46
|
EarlyStopping,
|
44
47
|
ModelEMA,
|
45
|
-
|
48
|
+
autocast,
|
49
|
+
convert_optimizer_state_dict_to_fp16,
|
46
50
|
init_seeds,
|
47
51
|
one_cycle,
|
48
52
|
select_device,
|
49
53
|
strip_optimizer,
|
54
|
+
torch_distributed_zero_first,
|
50
55
|
)
|
51
56
|
|
52
57
|
|
53
58
|
class BaseTrainer:
|
54
59
|
"""
|
55
|
-
BaseTrainer.
|
56
|
-
|
57
60
|
A base class for creating trainers.
|
58
61
|
|
59
62
|
Attributes:
|
@@ -107,7 +110,7 @@ class BaseTrainer:
|
|
107
110
|
self.save_dir = get_save_dir(self.args)
|
108
111
|
self.args.name = self.save_dir.name # update name for loggers
|
109
112
|
self.wdir = self.save_dir / "weights" # weights dir
|
110
|
-
if RANK in
|
113
|
+
if RANK in {-1, 0}:
|
111
114
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
112
115
|
self.args.save_dir = str(self.save_dir)
|
113
116
|
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
@@ -115,33 +118,19 @@ class BaseTrainer:
|
|
115
118
|
self.save_period = self.args.save_period
|
116
119
|
|
117
120
|
self.batch_size = self.args.batch
|
118
|
-
self.epochs = self.args.epochs
|
121
|
+
self.epochs = self.args.epochs or 100 # in case users accidentally pass epochs=None with timed training
|
119
122
|
self.start_epoch = 0
|
120
123
|
if RANK == -1:
|
121
124
|
print_args(vars(self.args))
|
122
125
|
|
123
126
|
# Device
|
124
|
-
if self.device.type in
|
127
|
+
if self.device.type in {"cpu", "mps"}:
|
125
128
|
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
126
129
|
|
127
130
|
# Model and Dataset
|
128
131
|
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
129
|
-
|
130
|
-
|
131
|
-
self.data = check_cls_dataset(self.args.data)
|
132
|
-
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
133
|
-
"detect",
|
134
|
-
"segment",
|
135
|
-
"pose",
|
136
|
-
"obb",
|
137
|
-
):
|
138
|
-
self.data = check_det_dataset(self.args.data)
|
139
|
-
if "yaml_file" in self.data:
|
140
|
-
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
141
|
-
except Exception as e:
|
142
|
-
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
143
|
-
|
144
|
-
self.trainset, self.testset = self.get_dataset(self.data)
|
132
|
+
with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times
|
133
|
+
self.trainset, self.testset = self.get_dataset()
|
145
134
|
self.ema = None
|
146
135
|
|
147
136
|
# Optimization utils init
|
@@ -157,9 +146,12 @@ class BaseTrainer:
|
|
157
146
|
self.csv = self.save_dir / "results.csv"
|
158
147
|
self.plot_idx = [0, 1, 2]
|
159
148
|
|
149
|
+
# HUB
|
150
|
+
self.hub_session = None
|
151
|
+
|
160
152
|
# Callbacks
|
161
153
|
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
162
|
-
if RANK in
|
154
|
+
if RANK in {-1, 0}:
|
163
155
|
callbacks.add_integration_callbacks(self)
|
164
156
|
|
165
157
|
def add_callback(self, event: str, callback):
|
@@ -181,9 +173,11 @@ class BaseTrainer:
|
|
181
173
|
world_size = len(self.args.device.split(","))
|
182
174
|
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
183
175
|
world_size = len(self.args.device)
|
176
|
+
elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps'
|
177
|
+
world_size = 0
|
184
178
|
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
185
179
|
world_size = 1 # default to device 0
|
186
|
-
else: # i.e. device=
|
180
|
+
else: # i.e. device=None or device=''
|
187
181
|
world_size = 0
|
188
182
|
|
189
183
|
# Run subprocess if DDP training, else train normally
|
@@ -192,9 +186,9 @@ class BaseTrainer:
|
|
192
186
|
if self.args.rect:
|
193
187
|
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
|
194
188
|
self.args.rect = False
|
195
|
-
if self.args.batch
|
189
|
+
if self.args.batch < 1.0:
|
196
190
|
LOGGER.warning(
|
197
|
-
"WARNING ⚠️ 'batch
|
191
|
+
"WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting "
|
198
192
|
"default 'batch=16'"
|
199
193
|
)
|
200
194
|
self.args.batch = 16
|
@@ -202,7 +196,7 @@ class BaseTrainer:
|
|
202
196
|
# Command
|
203
197
|
cmd, file = generate_ddp_command(world_size, self)
|
204
198
|
try:
|
205
|
-
LOGGER.info(f
|
199
|
+
LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
|
206
200
|
subprocess.run(cmd, check=True)
|
207
201
|
except Exception as e:
|
208
202
|
raise e
|
@@ -225,9 +219,9 @@ class BaseTrainer:
|
|
225
219
|
torch.cuda.set_device(RANK)
|
226
220
|
self.device = torch.device("cuda", RANK)
|
227
221
|
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
228
|
-
os.environ["
|
222
|
+
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
|
229
223
|
dist.init_process_group(
|
230
|
-
"nccl" if dist.is_nccl_available() else "gloo",
|
224
|
+
backend="nccl" if dist.is_nccl_available() else "gloo",
|
231
225
|
timeout=timedelta(seconds=10800), # 3 hours
|
232
226
|
rank=RANK,
|
233
227
|
world_size=world_size,
|
@@ -235,7 +229,6 @@ class BaseTrainer:
|
|
235
229
|
|
236
230
|
def _setup_train(self, world_size):
|
237
231
|
"""Builds dataloaders and optimizer on correct rank process."""
|
238
|
-
|
239
232
|
# Model
|
240
233
|
self.run_callbacks("on_pretrain_routine_start")
|
241
234
|
ckpt = self.setup_model()
|
@@ -266,16 +259,19 @@ class BaseTrainer:
|
|
266
259
|
|
267
260
|
# Check AMP
|
268
261
|
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
269
|
-
if self.amp and RANK in
|
262
|
+
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
|
270
263
|
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
271
264
|
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
272
265
|
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
273
266
|
if RANK > -1 and world_size > 1: # DDP
|
274
267
|
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
275
268
|
self.amp = bool(self.amp) # as boolean
|
276
|
-
self.scaler =
|
269
|
+
self.scaler = (
|
270
|
+
torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
|
271
|
+
)
|
277
272
|
if world_size > 1:
|
278
|
-
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
|
273
|
+
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
|
274
|
+
self.set_model_attributes() # set again after DDP wrapper
|
279
275
|
|
280
276
|
# Check imgsz
|
281
277
|
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
|
@@ -283,13 +279,13 @@ class BaseTrainer:
|
|
283
279
|
self.stride = gs # for multiscale training
|
284
280
|
|
285
281
|
# Batch size
|
286
|
-
if self.batch_size
|
287
|
-
self.args.batch = self.batch_size =
|
282
|
+
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
|
283
|
+
self.args.batch = self.batch_size = self.auto_batch()
|
288
284
|
|
289
285
|
# Dataloaders
|
290
286
|
batch_size = self.batch_size // max(world_size, 1)
|
291
|
-
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=
|
292
|
-
if RANK in
|
287
|
+
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
|
288
|
+
if RANK in {-1, 0}:
|
293
289
|
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
294
290
|
self.test_loader = self.get_dataloader(
|
295
291
|
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
@@ -334,18 +330,23 @@ class BaseTrainer:
|
|
334
330
|
self.train_time_start = time.time()
|
335
331
|
self.run_callbacks("on_train_start")
|
336
332
|
LOGGER.info(
|
337
|
-
f
|
338
|
-
f
|
333
|
+
f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
334
|
+
f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
|
339
335
|
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
340
|
-
f
|
336
|
+
f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
|
341
337
|
)
|
342
338
|
if self.args.close_mosaic:
|
343
339
|
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
344
340
|
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
345
341
|
epoch = self.start_epoch
|
342
|
+
self.optimizer.zero_grad() # zero any resumed gradients to ensure stability on train start
|
346
343
|
while True:
|
347
344
|
self.epoch = epoch
|
348
345
|
self.run_callbacks("on_train_epoch_start")
|
346
|
+
with warnings.catch_warnings():
|
347
|
+
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
348
|
+
self.scheduler.step()
|
349
|
+
|
349
350
|
self.model.train()
|
350
351
|
if RANK != -1:
|
351
352
|
self.train_loader.sampler.set_epoch(epoch)
|
@@ -355,11 +356,10 @@ class BaseTrainer:
|
|
355
356
|
self._close_dataloader_mosaic()
|
356
357
|
self.train_loader.reset()
|
357
358
|
|
358
|
-
if RANK in
|
359
|
+
if RANK in {-1, 0}:
|
359
360
|
LOGGER.info(self.progress_string())
|
360
361
|
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
361
362
|
self.tloss = None
|
362
|
-
self.optimizer.zero_grad()
|
363
363
|
for i, batch in pbar:
|
364
364
|
self.run_callbacks("on_train_batch_start")
|
365
365
|
# Warmup
|
@@ -376,7 +376,7 @@ class BaseTrainer:
|
|
376
376
|
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
377
377
|
|
378
378
|
# Forward
|
379
|
-
with
|
379
|
+
with autocast(self.amp):
|
380
380
|
batch = self.preprocess_batch(batch)
|
381
381
|
self.loss, self.loss_items = self.model(batch)
|
382
382
|
if RANK != -1:
|
@@ -404,13 +404,17 @@ class BaseTrainer:
|
|
404
404
|
break
|
405
405
|
|
406
406
|
# Log
|
407
|
-
|
408
|
-
|
409
|
-
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
410
|
-
if RANK in (-1, 0):
|
407
|
+
if RANK in {-1, 0}:
|
408
|
+
loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
411
409
|
pbar.set_description(
|
412
|
-
("%11s" * 2 + "%11.4g" * (2 +
|
413
|
-
% (
|
410
|
+
("%11s" * 2 + "%11.4g" * (2 + loss_length))
|
411
|
+
% (
|
412
|
+
f"{epoch + 1}/{self.epochs}",
|
413
|
+
f"{self._get_memory():.3g}G", # (GB) GPU memory util
|
414
|
+
*(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)), # losses
|
415
|
+
batch["cls"].shape[0], # batch size, i.e. 8
|
416
|
+
batch["img"].shape[-1], # imgsz, i.e 640
|
417
|
+
)
|
414
418
|
)
|
415
419
|
self.run_callbacks("on_batch_end")
|
416
420
|
if self.args.plots and ni in self.plot_idx:
|
@@ -420,8 +424,8 @@ class BaseTrainer:
|
|
420
424
|
|
421
425
|
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
422
426
|
self.run_callbacks("on_train_epoch_end")
|
423
|
-
if RANK in
|
424
|
-
final_epoch = epoch + 1
|
427
|
+
if RANK in {-1, 0}:
|
428
|
+
final_epoch = epoch + 1 >= self.epochs
|
425
429
|
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
426
430
|
|
427
431
|
# Validation
|
@@ -441,17 +445,14 @@ class BaseTrainer:
|
|
441
445
|
t = time.time()
|
442
446
|
self.epoch_time = t - self.epoch_time_start
|
443
447
|
self.epoch_time_start = t
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
self.scheduler.last_epoch = self.epoch # do not move
|
451
|
-
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
452
|
-
self.scheduler.step()
|
448
|
+
if self.args.time:
|
449
|
+
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
450
|
+
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
451
|
+
self._setup_scheduler()
|
452
|
+
self.scheduler.last_epoch = self.epoch # do not move
|
453
|
+
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
453
454
|
self.run_callbacks("on_fit_epoch_end")
|
454
|
-
|
455
|
+
self._clear_memory()
|
455
456
|
|
456
457
|
# Early Stopping
|
457
458
|
if RANK != -1: # if DDP training
|
@@ -462,55 +463,109 @@ class BaseTrainer:
|
|
462
463
|
break # must break all DDP ranks
|
463
464
|
epoch += 1
|
464
465
|
|
465
|
-
if RANK in
|
466
|
+
if RANK in {-1, 0}:
|
466
467
|
# Do final val with best.pt
|
467
|
-
|
468
|
-
|
469
|
-
f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
|
470
|
-
)
|
468
|
+
seconds = time.time() - self.train_time_start
|
469
|
+
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
|
471
470
|
self.final_eval()
|
472
471
|
if self.args.plots:
|
473
472
|
self.plot_metrics()
|
474
473
|
self.run_callbacks("on_train_end")
|
475
|
-
|
474
|
+
self._clear_memory()
|
476
475
|
self.run_callbacks("teardown")
|
477
476
|
|
477
|
+
def auto_batch(self, max_num_obj=0):
|
478
|
+
"""Get batch size by calculating memory occupation of model."""
|
479
|
+
return check_train_batch_size(
|
480
|
+
model=self.model,
|
481
|
+
imgsz=self.args.imgsz,
|
482
|
+
amp=self.amp,
|
483
|
+
batch=self.batch_size,
|
484
|
+
max_num_obj=max_num_obj,
|
485
|
+
) # returns batch size
|
486
|
+
|
487
|
+
def _get_memory(self):
|
488
|
+
"""Get accelerator memory utilization in GB."""
|
489
|
+
if self.device.type == "mps":
|
490
|
+
memory = torch.mps.driver_allocated_memory()
|
491
|
+
elif self.device.type == "cpu":
|
492
|
+
memory = 0
|
493
|
+
else:
|
494
|
+
memory = torch.cuda.memory_reserved()
|
495
|
+
return memory / 1e9
|
496
|
+
|
497
|
+
def _clear_memory(self):
|
498
|
+
"""Clear accelerator memory on different platforms."""
|
499
|
+
gc.collect()
|
500
|
+
if self.device.type == "mps":
|
501
|
+
torch.mps.empty_cache()
|
502
|
+
elif self.device.type == "cpu":
|
503
|
+
return
|
504
|
+
else:
|
505
|
+
torch.cuda.empty_cache()
|
506
|
+
|
507
|
+
def read_results_csv(self):
|
508
|
+
"""Read results.csv into a dict using pandas."""
|
509
|
+
import pandas as pd # scope for faster 'import ultralytics'
|
510
|
+
|
511
|
+
return pd.read_csv(self.csv).to_dict(orient="list")
|
512
|
+
|
478
513
|
def save_model(self):
|
479
514
|
"""Save model training checkpoints with additional metadata."""
|
480
|
-
import
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
515
|
+
import io
|
516
|
+
|
517
|
+
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
|
518
|
+
buffer = io.BytesIO()
|
519
|
+
torch.save(
|
520
|
+
{
|
521
|
+
"epoch": self.epoch,
|
522
|
+
"best_fitness": self.best_fitness,
|
523
|
+
"model": None, # resume and final checkpoints derive from EMA
|
524
|
+
"ema": deepcopy(self.ema.ema).half(),
|
525
|
+
"updates": self.ema.updates,
|
526
|
+
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
|
527
|
+
"train_args": vars(self.args), # save as dict
|
528
|
+
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
|
529
|
+
"train_results": self.read_results_csv(),
|
530
|
+
"date": datetime.now().isoformat(),
|
531
|
+
"version": __version__,
|
532
|
+
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
533
|
+
"docs": "https://docs.ultralytics.com",
|
534
|
+
},
|
535
|
+
buffer,
|
536
|
+
)
|
537
|
+
serialized_ckpt = buffer.getvalue() # get the serialized content to save
|
538
|
+
|
539
|
+
# Save checkpoints
|
540
|
+
self.last.write_bytes(serialized_ckpt) # save last.pt
|
502
541
|
if self.best_fitness == self.fitness:
|
503
|
-
|
504
|
-
if (self.save_period > 0) and (self.epoch
|
505
|
-
|
542
|
+
self.best.write_bytes(serialized_ckpt) # save best.pt
|
543
|
+
if (self.save_period > 0) and (self.epoch % self.save_period == 0):
|
544
|
+
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
545
|
+
# if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
|
546
|
+
# (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt) # save mosaic checkpoint
|
506
547
|
|
507
|
-
|
508
|
-
def get_dataset(data):
|
548
|
+
def get_dataset(self):
|
509
549
|
"""
|
510
550
|
Get train, val path from data dict if it exists.
|
511
551
|
|
512
552
|
Returns None if data format is not recognized.
|
513
553
|
"""
|
554
|
+
try:
|
555
|
+
if self.args.task == "classify":
|
556
|
+
data = check_cls_dataset(self.args.data)
|
557
|
+
elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
|
558
|
+
"detect",
|
559
|
+
"segment",
|
560
|
+
"pose",
|
561
|
+
"obb",
|
562
|
+
}:
|
563
|
+
data = check_det_dataset(self.args.data)
|
564
|
+
if "yaml_file" in data:
|
565
|
+
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
566
|
+
except Exception as e:
|
567
|
+
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
568
|
+
self.data = data
|
514
569
|
return data["train"], data.get("val") or data.get("test")
|
515
570
|
|
516
571
|
def setup_model(self):
|
@@ -518,13 +573,13 @@ class BaseTrainer:
|
|
518
573
|
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
519
574
|
return
|
520
575
|
|
521
|
-
|
576
|
+
cfg, weights = self.model, None
|
522
577
|
ckpt = None
|
523
|
-
if str(model).endswith(".pt"):
|
524
|
-
weights, ckpt = attempt_load_one_weight(model)
|
525
|
-
cfg =
|
526
|
-
|
527
|
-
|
578
|
+
if str(self.model).endswith(".pt"):
|
579
|
+
weights, ckpt = attempt_load_one_weight(self.model)
|
580
|
+
cfg = weights.yaml
|
581
|
+
elif isinstance(self.args.pretrained, (str, Path)):
|
582
|
+
weights, _ = attempt_load_one_weight(self.args.pretrained)
|
528
583
|
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
529
584
|
return ckpt
|
530
585
|
|
@@ -603,26 +658,31 @@ class BaseTrainer:
|
|
603
658
|
def save_metrics(self, metrics):
|
604
659
|
"""Saves training metrics to a CSV file."""
|
605
660
|
keys, vals = list(metrics.keys()), list(metrics.values())
|
606
|
-
n = len(metrics) +
|
607
|
-
s = "" if self.csv.exists() else (("%
|
661
|
+
n = len(metrics) + 2 # number of cols
|
662
|
+
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
|
663
|
+
t = time.time() - self.train_time_start
|
608
664
|
with open(self.csv, "a") as f:
|
609
|
-
f.write(s + ("
|
665
|
+
f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
|
610
666
|
|
611
667
|
def plot_metrics(self):
|
612
668
|
"""Plot and display metrics visually."""
|
613
669
|
pass
|
614
670
|
|
615
671
|
def on_plot(self, name, data=None):
|
616
|
-
"""Registers plots (e.g. to be consumed in callbacks)"""
|
672
|
+
"""Registers plots (e.g. to be consumed in callbacks)."""
|
617
673
|
path = Path(name)
|
618
674
|
self.plots[path] = {"data": data, "timestamp": time.time()}
|
619
675
|
|
620
676
|
def final_eval(self):
|
621
677
|
"""Performs final evaluation and validation for object detection YOLO model."""
|
678
|
+
ckpt = {}
|
622
679
|
for f in self.last, self.best:
|
623
680
|
if f.exists():
|
624
|
-
|
625
|
-
|
681
|
+
if f is self.last:
|
682
|
+
ckpt = strip_optimizer(f)
|
683
|
+
elif f is self.best:
|
684
|
+
k = "train_results" # update best.pt train_metrics from last.pt
|
685
|
+
strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
|
626
686
|
LOGGER.info(f"\nValidating {f}...")
|
627
687
|
self.validator.args.plots = self.args.plots
|
628
688
|
self.metrics = self.validator(model=f)
|
@@ -644,8 +704,13 @@ class BaseTrainer:
|
|
644
704
|
|
645
705
|
resume = True
|
646
706
|
self.args = get_cfg(ckpt_args)
|
647
|
-
self.args.model = str(last) # reinstate model
|
648
|
-
for k in
|
707
|
+
self.args.model = self.args.resume = str(last) # reinstate model
|
708
|
+
for k in (
|
709
|
+
"imgsz",
|
710
|
+
"batch",
|
711
|
+
"device",
|
712
|
+
"close_mosaic",
|
713
|
+
): # allow arg updates to reduce memory or update device on resume
|
649
714
|
if k in overrides:
|
650
715
|
setattr(self.args, k, overrides[k])
|
651
716
|
|
@@ -658,24 +723,21 @@ class BaseTrainer:
|
|
658
723
|
|
659
724
|
def resume_training(self, ckpt):
|
660
725
|
"""Resume YOLO training from given epoch and best fitness."""
|
661
|
-
if ckpt is None:
|
726
|
+
if ckpt is None or not self.resume:
|
662
727
|
return
|
663
728
|
best_fitness = 0.0
|
664
|
-
start_epoch = ckpt
|
665
|
-
if ckpt
|
729
|
+
start_epoch = ckpt.get("epoch", -1) + 1
|
730
|
+
if ckpt.get("optimizer", None) is not None:
|
666
731
|
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
667
732
|
best_fitness = ckpt["best_fitness"]
|
668
733
|
if self.ema and ckpt.get("ema"):
|
669
734
|
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
|
670
735
|
self.ema.updates = ckpt["updates"]
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
LOGGER.info(
|
677
|
-
f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
|
678
|
-
)
|
736
|
+
assert start_epoch > 0, (
|
737
|
+
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
|
738
|
+
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
739
|
+
)
|
740
|
+
LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
|
679
741
|
if self.epochs < start_epoch:
|
680
742
|
LOGGER.info(
|
681
743
|
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
|
@@ -692,7 +754,7 @@ class BaseTrainer:
|
|
692
754
|
self.train_loader.dataset.mosaic = False
|
693
755
|
if hasattr(self.train_loader.dataset, "close_mosaic"):
|
694
756
|
LOGGER.info("Closing dataloader mosaic")
|
695
|
-
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
757
|
+
self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
|
696
758
|
|
697
759
|
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
698
760
|
"""
|
@@ -712,7 +774,6 @@ class BaseTrainer:
|
|
712
774
|
Returns:
|
713
775
|
(torch.optim.Optimizer): The constructed optimizer.
|
714
776
|
"""
|
715
|
-
|
716
777
|
g = [], [], [] # optimizer parameter groups
|
717
778
|
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
718
779
|
if name == "auto":
|
@@ -736,7 +797,9 @@ class BaseTrainer:
|
|
736
797
|
else: # weight (with decay)
|
737
798
|
g[0].append(param)
|
738
799
|
|
739
|
-
|
800
|
+
optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
|
801
|
+
name = {x.lower(): x for x in optimizers}.get(name.lower())
|
802
|
+
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
|
740
803
|
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
741
804
|
elif name == "RMSProp":
|
742
805
|
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
@@ -744,15 +807,14 @@ class BaseTrainer:
|
|
744
807
|
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
745
808
|
else:
|
746
809
|
raise NotImplementedError(
|
747
|
-
f"Optimizer '{name}' not found in list of available optimizers "
|
748
|
-
|
749
|
-
"To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
|
810
|
+
f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
|
811
|
+
"Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
|
750
812
|
)
|
751
813
|
|
752
814
|
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
|
753
815
|
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
|
754
816
|
LOGGER.info(
|
755
817
|
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
756
|
-
f
|
818
|
+
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
|
757
819
|
)
|
758
820
|
return optimizer
|