dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
|
@@ -6,6 +6,8 @@ Usage:
|
|
|
6
6
|
$ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
9
11
|
import gc
|
|
10
12
|
import math
|
|
11
13
|
import os
|
|
@@ -24,9 +26,10 @@ from torch import nn, optim
|
|
|
24
26
|
from ultralytics import __version__
|
|
25
27
|
from ultralytics.cfg import get_cfg, get_save_dir
|
|
26
28
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
|
27
|
-
from ultralytics.nn.tasks import
|
|
29
|
+
from ultralytics.nn.tasks import load_checkpoint
|
|
28
30
|
from ultralytics.utils import (
|
|
29
31
|
DEFAULT_CFG,
|
|
32
|
+
GIT,
|
|
30
33
|
LOCAL_RANK,
|
|
31
34
|
LOGGER,
|
|
32
35
|
RANK,
|
|
@@ -41,10 +44,12 @@ from ultralytics.utils.autobatch import check_train_batch_size
|
|
|
41
44
|
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
|
|
42
45
|
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
|
43
46
|
from ultralytics.utils.files import get_latest_run
|
|
47
|
+
from ultralytics.utils.plotting import plot_results
|
|
44
48
|
from ultralytics.utils.torch_utils import (
|
|
45
49
|
TORCH_2_4,
|
|
46
50
|
EarlyStopping,
|
|
47
51
|
ModelEMA,
|
|
52
|
+
attempt_compile,
|
|
48
53
|
autocast,
|
|
49
54
|
convert_optimizer_state_dict_to_fp16,
|
|
50
55
|
init_seeds,
|
|
@@ -53,12 +58,15 @@ from ultralytics.utils.torch_utils import (
|
|
|
53
58
|
strip_optimizer,
|
|
54
59
|
torch_distributed_zero_first,
|
|
55
60
|
unset_deterministic,
|
|
61
|
+
unwrap_model,
|
|
56
62
|
)
|
|
57
63
|
|
|
58
64
|
|
|
59
65
|
class BaseTrainer:
|
|
60
|
-
"""
|
|
61
|
-
|
|
66
|
+
"""A base class for creating trainers.
|
|
67
|
+
|
|
68
|
+
This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
|
|
69
|
+
and various training utilities. It supports both single-GPU and multi-GPU distributed training.
|
|
62
70
|
|
|
63
71
|
Attributes:
|
|
64
72
|
args (SimpleNamespace): Configuration for the trainer.
|
|
@@ -89,21 +97,34 @@ class BaseTrainer:
|
|
|
89
97
|
csv (Path): Path to results CSV file.
|
|
90
98
|
metrics (dict): Dictionary of metrics.
|
|
91
99
|
plots (dict): Dictionary of plots.
|
|
100
|
+
|
|
101
|
+
Methods:
|
|
102
|
+
train: Execute the training process.
|
|
103
|
+
validate: Run validation on the test set.
|
|
104
|
+
save_model: Save model training checkpoints.
|
|
105
|
+
get_dataset: Get train and validation datasets.
|
|
106
|
+
setup_model: Load, create, or download model.
|
|
107
|
+
build_optimizer: Construct an optimizer for the model.
|
|
108
|
+
|
|
109
|
+
Examples:
|
|
110
|
+
Initialize a trainer and start training
|
|
111
|
+
>>> trainer = BaseTrainer(cfg="config.yaml")
|
|
112
|
+
>>> trainer.train()
|
|
92
113
|
"""
|
|
93
114
|
|
|
94
115
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
95
|
-
"""
|
|
96
|
-
Initialize the BaseTrainer class.
|
|
116
|
+
"""Initialize the BaseTrainer class.
|
|
97
117
|
|
|
98
118
|
Args:
|
|
99
|
-
cfg (str, optional): Path to a configuration file.
|
|
100
|
-
overrides (dict, optional): Configuration overrides.
|
|
101
|
-
_callbacks (list, optional): List of callback functions.
|
|
119
|
+
cfg (str, optional): Path to a configuration file.
|
|
120
|
+
overrides (dict, optional): Configuration overrides.
|
|
121
|
+
_callbacks (list, optional): List of callback functions.
|
|
102
122
|
"""
|
|
123
|
+
self.hub_session = overrides.pop("session", None) # HUB
|
|
103
124
|
self.args = get_cfg(cfg, overrides)
|
|
104
125
|
self.check_resume(overrides)
|
|
105
|
-
self.device = select_device(self.args.device
|
|
106
|
-
#
|
|
126
|
+
self.device = select_device(self.args.device)
|
|
127
|
+
# Update "-1" devices so post-training val does not repeat search
|
|
107
128
|
self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
|
|
108
129
|
self.validator = None
|
|
109
130
|
self.metrics = None
|
|
@@ -149,15 +170,32 @@ class BaseTrainer:
|
|
|
149
170
|
self.tloss = None
|
|
150
171
|
self.loss_names = ["Loss"]
|
|
151
172
|
self.csv = self.save_dir / "results.csv"
|
|
173
|
+
if self.csv.exists() and not self.args.resume:
|
|
174
|
+
self.csv.unlink()
|
|
152
175
|
self.plot_idx = [0, 1, 2]
|
|
153
|
-
|
|
154
|
-
# HUB
|
|
155
|
-
self.hub_session = None
|
|
176
|
+
self.nan_recovery_attempts = 0
|
|
156
177
|
|
|
157
178
|
# Callbacks
|
|
158
179
|
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
|
159
|
-
|
|
180
|
+
|
|
181
|
+
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
|
|
182
|
+
world_size = len(self.args.device.split(","))
|
|
183
|
+
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
|
184
|
+
world_size = len(self.args.device)
|
|
185
|
+
elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps'
|
|
186
|
+
world_size = 0
|
|
187
|
+
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
|
188
|
+
world_size = 1 # default to device 0
|
|
189
|
+
else: # i.e. device=None or device=''
|
|
190
|
+
world_size = 0
|
|
191
|
+
|
|
192
|
+
self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
|
|
193
|
+
self.world_size = world_size
|
|
194
|
+
# Run subprocess if DDP training, else train normally
|
|
195
|
+
if RANK in {-1, 0} and not self.ddp:
|
|
160
196
|
callbacks.add_integration_callbacks(self)
|
|
197
|
+
# Start console logging immediately at trainer initialization
|
|
198
|
+
self.run_callbacks("on_pretrain_routine_start")
|
|
161
199
|
|
|
162
200
|
def add_callback(self, event: str, callback):
|
|
163
201
|
"""Append the given callback to the event's callback list."""
|
|
@@ -174,31 +212,20 @@ class BaseTrainer:
|
|
|
174
212
|
|
|
175
213
|
def train(self):
|
|
176
214
|
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
|
177
|
-
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
|
|
178
|
-
world_size = len(self.args.device.split(","))
|
|
179
|
-
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
|
180
|
-
world_size = len(self.args.device)
|
|
181
|
-
elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps'
|
|
182
|
-
world_size = 0
|
|
183
|
-
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
|
184
|
-
world_size = 1 # default to device 0
|
|
185
|
-
else: # i.e. device=None or device=''
|
|
186
|
-
world_size = 0
|
|
187
|
-
|
|
188
215
|
# Run subprocess if DDP training, else train normally
|
|
189
|
-
if
|
|
216
|
+
if self.ddp:
|
|
190
217
|
# Argument checks
|
|
191
218
|
if self.args.rect:
|
|
192
219
|
LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
|
|
193
220
|
self.args.rect = False
|
|
194
221
|
if self.args.batch < 1.0:
|
|
195
|
-
|
|
196
|
-
"
|
|
222
|
+
raise ValueError(
|
|
223
|
+
"AutoBatch with batch<1 not supported for Multi-GPU training, "
|
|
224
|
+
f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
|
|
197
225
|
)
|
|
198
|
-
self.args.batch = 16
|
|
199
226
|
|
|
200
227
|
# Command
|
|
201
|
-
cmd, file = generate_ddp_command(
|
|
228
|
+
cmd, file = generate_ddp_command(self)
|
|
202
229
|
try:
|
|
203
230
|
LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
|
|
204
231
|
subprocess.run(cmd, check=True)
|
|
@@ -208,7 +235,7 @@ class BaseTrainer:
|
|
|
208
235
|
ddp_cleanup(self, str(file))
|
|
209
236
|
|
|
210
237
|
else:
|
|
211
|
-
self._do_train(
|
|
238
|
+
self._do_train()
|
|
212
239
|
|
|
213
240
|
def _setup_scheduler(self):
|
|
214
241
|
"""Initialize training learning rate scheduler."""
|
|
@@ -218,27 +245,27 @@ class BaseTrainer:
|
|
|
218
245
|
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
|
219
246
|
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
|
220
247
|
|
|
221
|
-
def _setup_ddp(self
|
|
248
|
+
def _setup_ddp(self):
|
|
222
249
|
"""Initialize and set the DistributedDataParallel parameters for training."""
|
|
223
250
|
torch.cuda.set_device(RANK)
|
|
224
251
|
self.device = torch.device("cuda", RANK)
|
|
225
|
-
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
|
226
252
|
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
|
|
227
253
|
dist.init_process_group(
|
|
228
254
|
backend="nccl" if dist.is_nccl_available() else "gloo",
|
|
229
255
|
timeout=timedelta(seconds=10800), # 3 hours
|
|
230
256
|
rank=RANK,
|
|
231
|
-
world_size=world_size,
|
|
257
|
+
world_size=self.world_size,
|
|
232
258
|
)
|
|
233
259
|
|
|
234
|
-
def _setup_train(self
|
|
260
|
+
def _setup_train(self):
|
|
235
261
|
"""Build dataloaders and optimizer on correct rank process."""
|
|
236
|
-
# Model
|
|
237
|
-
self.run_callbacks("on_pretrain_routine_start")
|
|
238
262
|
ckpt = self.setup_model()
|
|
239
263
|
self.model = self.model.to(self.device)
|
|
240
264
|
self.set_model_attributes()
|
|
241
265
|
|
|
266
|
+
# Compile model
|
|
267
|
+
self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
|
|
268
|
+
|
|
242
269
|
# Freeze layers
|
|
243
270
|
freeze_list = (
|
|
244
271
|
self.args.freeze
|
|
@@ -268,13 +295,13 @@ class BaseTrainer:
|
|
|
268
295
|
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
|
269
296
|
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
|
270
297
|
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
|
271
|
-
if RANK > -1 and world_size > 1: # DDP
|
|
298
|
+
if RANK > -1 and self.world_size > 1: # DDP
|
|
272
299
|
dist.broadcast(self.amp.int(), src=0) # broadcast from rank 0 to all other ranks; gloo errors with boolean
|
|
273
300
|
self.amp = bool(self.amp) # as boolean
|
|
274
301
|
self.scaler = (
|
|
275
302
|
torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
|
|
276
303
|
)
|
|
277
|
-
if world_size > 1:
|
|
304
|
+
if self.world_size > 1:
|
|
278
305
|
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
|
|
279
306
|
|
|
280
307
|
# Check imgsz
|
|
@@ -287,22 +314,22 @@ class BaseTrainer:
|
|
|
287
314
|
self.args.batch = self.batch_size = self.auto_batch()
|
|
288
315
|
|
|
289
316
|
# Dataloaders
|
|
290
|
-
batch_size = self.batch_size // max(world_size, 1)
|
|
317
|
+
batch_size = self.batch_size // max(self.world_size, 1)
|
|
291
318
|
self.train_loader = self.get_dataloader(
|
|
292
319
|
self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
|
|
293
320
|
)
|
|
321
|
+
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
|
322
|
+
self.test_loader = self.get_dataloader(
|
|
323
|
+
self.data.get("val") or self.data.get("test"),
|
|
324
|
+
batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
|
|
325
|
+
rank=LOCAL_RANK,
|
|
326
|
+
mode="val",
|
|
327
|
+
)
|
|
328
|
+
self.validator = self.get_validator()
|
|
329
|
+
self.ema = ModelEMA(self.model)
|
|
294
330
|
if RANK in {-1, 0}:
|
|
295
|
-
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
|
296
|
-
self.test_loader = self.get_dataloader(
|
|
297
|
-
self.data.get("val") or self.data.get("test"),
|
|
298
|
-
batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
|
|
299
|
-
rank=-1,
|
|
300
|
-
mode="val",
|
|
301
|
-
)
|
|
302
|
-
self.validator = self.get_validator()
|
|
303
331
|
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
|
|
304
332
|
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
|
|
305
|
-
self.ema = ModelEMA(self.model)
|
|
306
333
|
if self.args.plots:
|
|
307
334
|
self.plot_training_labels()
|
|
308
335
|
|
|
@@ -325,11 +352,11 @@ class BaseTrainer:
|
|
|
325
352
|
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
|
326
353
|
self.run_callbacks("on_pretrain_routine_end")
|
|
327
354
|
|
|
328
|
-
def _do_train(self
|
|
355
|
+
def _do_train(self):
|
|
329
356
|
"""Train the model with the specified world size."""
|
|
330
|
-
if world_size > 1:
|
|
331
|
-
self._setup_ddp(
|
|
332
|
-
self._setup_train(
|
|
357
|
+
if self.world_size > 1:
|
|
358
|
+
self._setup_ddp()
|
|
359
|
+
self._setup_train()
|
|
333
360
|
|
|
334
361
|
nb = len(self.train_loader) # number of batches
|
|
335
362
|
nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
|
|
@@ -340,7 +367,7 @@ class BaseTrainer:
|
|
|
340
367
|
self.run_callbacks("on_train_start")
|
|
341
368
|
LOGGER.info(
|
|
342
369
|
f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
|
343
|
-
f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
|
|
370
|
+
f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
|
|
344
371
|
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
|
345
372
|
f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
|
|
346
373
|
)
|
|
@@ -387,18 +414,19 @@ class BaseTrainer:
|
|
|
387
414
|
# Forward
|
|
388
415
|
with autocast(self.amp):
|
|
389
416
|
batch = self.preprocess_batch(batch)
|
|
390
|
-
|
|
417
|
+
if self.args.compile:
|
|
418
|
+
# Decouple inference and loss calculations for improved compile performance
|
|
419
|
+
preds = self.model(batch["img"])
|
|
420
|
+
loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
|
|
421
|
+
else:
|
|
422
|
+
loss, self.loss_items = self.model(batch)
|
|
391
423
|
self.loss = loss.sum()
|
|
392
424
|
if RANK != -1:
|
|
393
|
-
self.loss *= world_size
|
|
394
|
-
self.tloss = (
|
|
395
|
-
(self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
|
|
396
|
-
)
|
|
425
|
+
self.loss *= self.world_size
|
|
426
|
+
self.tloss = self.loss_items if self.tloss is None else (self.tloss * i + self.loss_items) / (i + 1)
|
|
397
427
|
|
|
398
428
|
# Backward
|
|
399
429
|
self.scaler.scale(self.loss).backward()
|
|
400
|
-
|
|
401
|
-
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
|
|
402
430
|
if ni - last_opt_step >= self.accumulate:
|
|
403
431
|
self.optimizer_step()
|
|
404
432
|
last_opt_step = ni
|
|
@@ -433,14 +461,23 @@ class BaseTrainer:
|
|
|
433
461
|
self.run_callbacks("on_train_batch_end")
|
|
434
462
|
|
|
435
463
|
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
|
464
|
+
|
|
436
465
|
self.run_callbacks("on_train_epoch_end")
|
|
437
466
|
if RANK in {-1, 0}:
|
|
438
467
|
final_epoch = epoch + 1 >= self.epochs
|
|
439
468
|
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
|
440
469
|
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
470
|
+
# Validation
|
|
471
|
+
if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
|
|
472
|
+
self._clear_memory(threshold=0.5) # prevent VRAM spike
|
|
473
|
+
self.metrics, self.fitness = self.validate()
|
|
474
|
+
|
|
475
|
+
# NaN recovery
|
|
476
|
+
if self._handle_nan_recovery(epoch):
|
|
477
|
+
continue
|
|
478
|
+
|
|
479
|
+
self.nan_recovery_attempts = 0
|
|
480
|
+
if RANK in {-1, 0}:
|
|
444
481
|
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
|
|
445
482
|
self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
|
|
446
483
|
if self.args.time:
|
|
@@ -462,8 +499,7 @@ class BaseTrainer:
|
|
|
462
499
|
self.scheduler.last_epoch = self.epoch # do not move
|
|
463
500
|
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
|
464
501
|
self.run_callbacks("on_fit_epoch_end")
|
|
465
|
-
|
|
466
|
-
self._clear_memory() # clear if memory utilization > 50%
|
|
502
|
+
self._clear_memory(0.5) # clear if memory utilization > 50%
|
|
467
503
|
|
|
468
504
|
# Early Stopping
|
|
469
505
|
if RANK != -1: # if DDP training
|
|
@@ -474,11 +510,11 @@ class BaseTrainer:
|
|
|
474
510
|
break # must break all DDP ranks
|
|
475
511
|
epoch += 1
|
|
476
512
|
|
|
513
|
+
seconds = time.time() - self.train_time_start
|
|
514
|
+
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
|
|
515
|
+
# Do final val with best.pt
|
|
516
|
+
self.final_eval()
|
|
477
517
|
if RANK in {-1, 0}:
|
|
478
|
-
# Do final val with best.pt
|
|
479
|
-
seconds = time.time() - self.train_time_start
|
|
480
|
-
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
|
|
481
|
-
self.final_eval()
|
|
482
518
|
if self.args.plots:
|
|
483
519
|
self.plot_metrics()
|
|
484
520
|
self.run_callbacks("on_train_end")
|
|
@@ -509,8 +545,12 @@ class BaseTrainer:
|
|
|
509
545
|
total = torch.cuda.get_device_properties(self.device).total_memory
|
|
510
546
|
return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
|
|
511
547
|
|
|
512
|
-
def _clear_memory(self):
|
|
548
|
+
def _clear_memory(self, threshold: float | None = None):
|
|
513
549
|
"""Clear accelerator memory by calling garbage collector and emptying cache."""
|
|
550
|
+
if threshold:
|
|
551
|
+
assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
|
|
552
|
+
if self._get_memory(fraction=True) <= threshold:
|
|
553
|
+
return
|
|
514
554
|
gc.collect()
|
|
515
555
|
if self.device.type == "mps":
|
|
516
556
|
torch.mps.empty_cache()
|
|
@@ -520,10 +560,13 @@ class BaseTrainer:
|
|
|
520
560
|
torch.cuda.empty_cache()
|
|
521
561
|
|
|
522
562
|
def read_results_csv(self):
|
|
523
|
-
"""Read results.csv into a dictionary using
|
|
524
|
-
import
|
|
563
|
+
"""Read results.csv into a dictionary using polars."""
|
|
564
|
+
import polars as pl # scope for faster 'import ultralytics'
|
|
525
565
|
|
|
526
|
-
|
|
566
|
+
try:
|
|
567
|
+
return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
|
|
568
|
+
except Exception:
|
|
569
|
+
return {}
|
|
527
570
|
|
|
528
571
|
def _model_train(self):
|
|
529
572
|
"""Set model in training mode."""
|
|
@@ -544,14 +587,21 @@ class BaseTrainer:
|
|
|
544
587
|
"epoch": self.epoch,
|
|
545
588
|
"best_fitness": self.best_fitness,
|
|
546
589
|
"model": None, # resume and final checkpoints derive from EMA
|
|
547
|
-
"ema": deepcopy(self.ema.ema).half(),
|
|
590
|
+
"ema": deepcopy(unwrap_model(self.ema.ema)).half(),
|
|
548
591
|
"updates": self.ema.updates,
|
|
549
592
|
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
|
|
593
|
+
"scaler": self.scaler.state_dict(),
|
|
550
594
|
"train_args": vars(self.args), # save as dict
|
|
551
595
|
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
|
|
552
596
|
"train_results": self.read_results_csv(),
|
|
553
597
|
"date": datetime.now().isoformat(),
|
|
554
598
|
"version": __version__,
|
|
599
|
+
"git": {
|
|
600
|
+
"root": str(GIT.root),
|
|
601
|
+
"branch": GIT.branch,
|
|
602
|
+
"commit": GIT.commit,
|
|
603
|
+
"origin": GIT.origin,
|
|
604
|
+
},
|
|
555
605
|
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
|
556
606
|
"docs": "https://docs.ultralytics.com",
|
|
557
607
|
},
|
|
@@ -560,17 +610,15 @@ class BaseTrainer:
|
|
|
560
610
|
serialized_ckpt = buffer.getvalue() # get the serialized content to save
|
|
561
611
|
|
|
562
612
|
# Save checkpoints
|
|
613
|
+
self.wdir.mkdir(parents=True, exist_ok=True) # ensure weights directory exists
|
|
563
614
|
self.last.write_bytes(serialized_ckpt) # save last.pt
|
|
564
615
|
if self.best_fitness == self.fitness:
|
|
565
616
|
self.best.write_bytes(serialized_ckpt) # save best.pt
|
|
566
617
|
if (self.save_period > 0) and (self.epoch % self.save_period == 0):
|
|
567
618
|
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
|
568
|
-
# if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
|
|
569
|
-
# (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt) # save mosaic checkpoint
|
|
570
619
|
|
|
571
620
|
def get_dataset(self):
|
|
572
|
-
"""
|
|
573
|
-
Get train and validation datasets from data dictionary.
|
|
621
|
+
"""Get train and validation datasets from data dictionary.
|
|
574
622
|
|
|
575
623
|
Returns:
|
|
576
624
|
(dict): A dictionary containing the training/validation/test dataset and category names.
|
|
@@ -578,7 +626,16 @@ class BaseTrainer:
|
|
|
578
626
|
try:
|
|
579
627
|
if self.args.task == "classify":
|
|
580
628
|
data = check_cls_dataset(self.args.data)
|
|
581
|
-
elif self.args.data.
|
|
629
|
+
elif self.args.data.rsplit(".", 1)[-1] == "ndjson":
|
|
630
|
+
# Convert NDJSON to YOLO format
|
|
631
|
+
import asyncio
|
|
632
|
+
|
|
633
|
+
from ultralytics.data.converter import convert_ndjson_to_yolo
|
|
634
|
+
|
|
635
|
+
yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
|
|
636
|
+
self.args.data = str(yaml_path)
|
|
637
|
+
data = check_det_dataset(self.args.data)
|
|
638
|
+
elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
|
|
582
639
|
"detect",
|
|
583
640
|
"segment",
|
|
584
641
|
"pose",
|
|
@@ -596,8 +653,7 @@ class BaseTrainer:
|
|
|
596
653
|
return data
|
|
597
654
|
|
|
598
655
|
def setup_model(self):
|
|
599
|
-
"""
|
|
600
|
-
Load, create, or download model for any task.
|
|
656
|
+
"""Load, create, or download model for any task.
|
|
601
657
|
|
|
602
658
|
Returns:
|
|
603
659
|
(dict): Optional checkpoint to resume training from.
|
|
@@ -608,17 +664,17 @@ class BaseTrainer:
|
|
|
608
664
|
cfg, weights = self.model, None
|
|
609
665
|
ckpt = None
|
|
610
666
|
if str(self.model).endswith(".pt"):
|
|
611
|
-
weights, ckpt =
|
|
667
|
+
weights, ckpt = load_checkpoint(self.model)
|
|
612
668
|
cfg = weights.yaml
|
|
613
669
|
elif isinstance(self.args.pretrained, (str, Path)):
|
|
614
|
-
weights, _ =
|
|
670
|
+
weights, _ = load_checkpoint(self.args.pretrained)
|
|
615
671
|
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
|
616
672
|
return ckpt
|
|
617
673
|
|
|
618
674
|
def optimizer_step(self):
|
|
619
675
|
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
|
|
620
676
|
self.scaler.unscale_(self.optimizer) # unscale gradients
|
|
621
|
-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
|
677
|
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
|
622
678
|
self.scaler.step(self.optimizer)
|
|
623
679
|
self.scaler.update()
|
|
624
680
|
self.optimizer.zero_grad()
|
|
@@ -626,17 +682,23 @@ class BaseTrainer:
|
|
|
626
682
|
self.ema.update(self.model)
|
|
627
683
|
|
|
628
684
|
def preprocess_batch(self, batch):
|
|
629
|
-
"""
|
|
685
|
+
"""Allow custom preprocessing model inputs and ground truths depending on task type."""
|
|
630
686
|
return batch
|
|
631
687
|
|
|
632
688
|
def validate(self):
|
|
633
|
-
"""
|
|
634
|
-
Run validation on test set using self.validator.
|
|
689
|
+
"""Run validation on val set using self.validator.
|
|
635
690
|
|
|
636
691
|
Returns:
|
|
637
|
-
(
|
|
692
|
+
metrics (dict): Dictionary of validation metrics.
|
|
693
|
+
fitness (float): Fitness score for the validation.
|
|
638
694
|
"""
|
|
695
|
+
if self.ema and self.world_size > 1:
|
|
696
|
+
# Sync EMA buffers from rank 0 to all ranks
|
|
697
|
+
for buffer in self.ema.ema.buffers():
|
|
698
|
+
dist.broadcast(buffer, src=0)
|
|
639
699
|
metrics = self.validator(self)
|
|
700
|
+
if metrics is None:
|
|
701
|
+
return None, None
|
|
640
702
|
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
|
641
703
|
if not self.best_fitness or self.best_fitness < fitness:
|
|
642
704
|
self.best_fitness = fitness
|
|
@@ -647,11 +709,11 @@ class BaseTrainer:
|
|
|
647
709
|
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
|
648
710
|
|
|
649
711
|
def get_validator(self):
|
|
650
|
-
"""
|
|
712
|
+
"""Return a NotImplementedError when the get_validator function is called."""
|
|
651
713
|
raise NotImplementedError("get_validator function not implemented in trainer")
|
|
652
714
|
|
|
653
715
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
|
654
|
-
"""
|
|
716
|
+
"""Return dataloader derived from torch.data.Dataloader."""
|
|
655
717
|
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
|
656
718
|
|
|
657
719
|
def build_dataset(self, img_path, mode="train", batch=None):
|
|
@@ -659,10 +721,9 @@ class BaseTrainer:
|
|
|
659
721
|
raise NotImplementedError("build_dataset function not implemented in trainer")
|
|
660
722
|
|
|
661
723
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
|
662
|
-
"""
|
|
663
|
-
Returns a loss dict with labelled training loss items tensor.
|
|
724
|
+
"""Return a loss dict with labeled training loss items tensor.
|
|
664
725
|
|
|
665
|
-
|
|
726
|
+
Notes:
|
|
666
727
|
This is not needed for classification but necessary for segmentation & detection
|
|
667
728
|
"""
|
|
668
729
|
return {"loss": loss_items} if loss_items is not None else ["loss"]
|
|
@@ -672,55 +733,57 @@ class BaseTrainer:
|
|
|
672
733
|
self.model.names = self.data["names"]
|
|
673
734
|
|
|
674
735
|
def build_targets(self, preds, targets):
|
|
675
|
-
"""
|
|
736
|
+
"""Build target tensors for training YOLO model."""
|
|
676
737
|
pass
|
|
677
738
|
|
|
678
739
|
def progress_string(self):
|
|
679
|
-
"""
|
|
740
|
+
"""Return a string describing training progress."""
|
|
680
741
|
return ""
|
|
681
742
|
|
|
682
743
|
# TODO: may need to put these following functions into callback
|
|
683
744
|
def plot_training_samples(self, batch, ni):
|
|
684
|
-
"""
|
|
745
|
+
"""Plot training samples during YOLO training."""
|
|
685
746
|
pass
|
|
686
747
|
|
|
687
748
|
def plot_training_labels(self):
|
|
688
|
-
"""
|
|
749
|
+
"""Plot training labels for YOLO model."""
|
|
689
750
|
pass
|
|
690
751
|
|
|
691
752
|
def save_metrics(self, metrics):
|
|
692
753
|
"""Save training metrics to a CSV file."""
|
|
693
754
|
keys, vals = list(metrics.keys()), list(metrics.values())
|
|
694
755
|
n = len(metrics) + 2 # number of cols
|
|
695
|
-
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
|
|
696
756
|
t = time.time() - self.train_time_start
|
|
757
|
+
self.csv.parent.mkdir(parents=True, exist_ok=True) # ensure parent directory exists
|
|
758
|
+
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time", *keys])).rstrip(",") + "\n") # header
|
|
697
759
|
with open(self.csv, "a", encoding="utf-8") as f:
|
|
698
|
-
f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t
|
|
760
|
+
f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t, *vals])).rstrip(",") + "\n")
|
|
699
761
|
|
|
700
762
|
def plot_metrics(self):
|
|
701
|
-
"""Plot
|
|
702
|
-
|
|
763
|
+
"""Plot metrics from a CSV file."""
|
|
764
|
+
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
|
|
703
765
|
|
|
704
766
|
def on_plot(self, name, data=None):
|
|
705
|
-
"""
|
|
767
|
+
"""Register plots (e.g. to be consumed in callbacks)."""
|
|
706
768
|
path = Path(name)
|
|
707
769
|
self.plots[path] = {"data": data, "timestamp": time.time()}
|
|
708
770
|
|
|
709
771
|
def final_eval(self):
|
|
710
772
|
"""Perform final evaluation and validation for object detection YOLO model."""
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
if
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
773
|
+
model = self.best if self.best.exists() else None
|
|
774
|
+
with torch_distributed_zero_first(LOCAL_RANK): # strip only on GPU 0; other GPUs should wait
|
|
775
|
+
if RANK in {-1, 0}:
|
|
776
|
+
ckpt = strip_optimizer(self.last) if self.last.exists() else {}
|
|
777
|
+
if model:
|
|
778
|
+
# update best.pt train_metrics from last.pt
|
|
779
|
+
strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
|
|
780
|
+
if model:
|
|
781
|
+
LOGGER.info(f"\nValidating {model}...")
|
|
782
|
+
self.validator.args.plots = self.args.plots
|
|
783
|
+
self.validator.args.compile = False # disable final val compile as too slow
|
|
784
|
+
self.metrics = self.validator(model=model)
|
|
785
|
+
self.metrics.pop("fitness", None)
|
|
786
|
+
self.run_callbacks("on_fit_epoch_end")
|
|
724
787
|
|
|
725
788
|
def check_resume(self, overrides):
|
|
726
789
|
"""Check if resume checkpoint exists and update arguments accordingly."""
|
|
@@ -731,7 +794,7 @@ class BaseTrainer:
|
|
|
731
794
|
last = Path(check_file(resume) if exists else get_latest_run())
|
|
732
795
|
|
|
733
796
|
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
|
|
734
|
-
ckpt_args =
|
|
797
|
+
ckpt_args = load_checkpoint(last)[0].args
|
|
735
798
|
if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
|
|
736
799
|
ckpt_args["data"] = self.args.data
|
|
737
800
|
|
|
@@ -754,18 +817,54 @@ class BaseTrainer:
|
|
|
754
817
|
) from e
|
|
755
818
|
self.resume = resume
|
|
756
819
|
|
|
820
|
+
def _load_checkpoint_state(self, ckpt):
|
|
821
|
+
"""Load optimizer, scaler, EMA, and best_fitness from checkpoint."""
|
|
822
|
+
if ckpt.get("optimizer") is not None:
|
|
823
|
+
self.optimizer.load_state_dict(ckpt["optimizer"])
|
|
824
|
+
if ckpt.get("scaler") is not None:
|
|
825
|
+
self.scaler.load_state_dict(ckpt["scaler"])
|
|
826
|
+
if self.ema and ckpt.get("ema"):
|
|
827
|
+
self.ema = ModelEMA(self.model) # validation with EMA creates inference tensors that can't be updated
|
|
828
|
+
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
|
|
829
|
+
self.ema.updates = ckpt["updates"]
|
|
830
|
+
self.best_fitness = ckpt.get("best_fitness", 0.0)
|
|
831
|
+
|
|
832
|
+
def _handle_nan_recovery(self, epoch):
|
|
833
|
+
"""Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint."""
|
|
834
|
+
loss_nan = self.loss is not None and not self.loss.isfinite()
|
|
835
|
+
fitness_nan = self.fitness is not None and not np.isfinite(self.fitness)
|
|
836
|
+
fitness_collapse = self.best_fitness and self.best_fitness > 0 and self.fitness == 0
|
|
837
|
+
corrupted = RANK in {-1, 0} and loss_nan and (fitness_nan or fitness_collapse)
|
|
838
|
+
reason = "Loss NaN/Inf" if loss_nan else "Fitness NaN/Inf" if fitness_nan else "Fitness collapse"
|
|
839
|
+
if RANK != -1: # DDP: broadcast to all ranks
|
|
840
|
+
broadcast_list = [corrupted if RANK == 0 else None]
|
|
841
|
+
dist.broadcast_object_list(broadcast_list, 0)
|
|
842
|
+
corrupted = broadcast_list[0]
|
|
843
|
+
if not corrupted:
|
|
844
|
+
return False
|
|
845
|
+
if epoch == self.start_epoch or not self.last.exists():
|
|
846
|
+
LOGGER.warning(f"{reason} detected but can not recover from last.pt...")
|
|
847
|
+
return False # Cannot recover on first epoch, let training continue
|
|
848
|
+
self.nan_recovery_attempts += 1
|
|
849
|
+
if self.nan_recovery_attempts > 3:
|
|
850
|
+
raise RuntimeError(f"Training failed: NaN persisted for {self.nan_recovery_attempts} epochs")
|
|
851
|
+
LOGGER.warning(f"{reason} detected (attempt {self.nan_recovery_attempts}/3), recovering from last.pt...")
|
|
852
|
+
self._model_train() # set model to train mode before loading checkpoint to avoid inference tensor errors
|
|
853
|
+
_, ckpt = load_checkpoint(self.last)
|
|
854
|
+
ema_state = ckpt["ema"].float().state_dict()
|
|
855
|
+
if not all(torch.isfinite(v).all() for v in ema_state.values() if isinstance(v, torch.Tensor)):
|
|
856
|
+
raise RuntimeError(f"Checkpoint {self.last} is corrupted with NaN/Inf weights")
|
|
857
|
+
unwrap_model(self.model).load_state_dict(ema_state) # Load EMA weights into model
|
|
858
|
+
self._load_checkpoint_state(ckpt) # Load optimizer/scaler/EMA/best_fitness
|
|
859
|
+
del ckpt, ema_state
|
|
860
|
+
self.scheduler.last_epoch = epoch - 1
|
|
861
|
+
return True
|
|
862
|
+
|
|
757
863
|
def resume_training(self, ckpt):
|
|
758
864
|
"""Resume YOLO training from given epoch and best fitness."""
|
|
759
865
|
if ckpt is None or not self.resume:
|
|
760
866
|
return
|
|
761
|
-
best_fitness = 0.0
|
|
762
867
|
start_epoch = ckpt.get("epoch", -1) + 1
|
|
763
|
-
if ckpt.get("optimizer", None) is not None:
|
|
764
|
-
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
|
765
|
-
best_fitness = ckpt["best_fitness"]
|
|
766
|
-
if self.ema and ckpt.get("ema"):
|
|
767
|
-
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
|
|
768
|
-
self.ema.updates = ckpt["updates"]
|
|
769
868
|
assert start_epoch > 0, (
|
|
770
869
|
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
|
|
771
870
|
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
|
@@ -776,7 +875,7 @@ class BaseTrainer:
|
|
|
776
875
|
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
|
|
777
876
|
)
|
|
778
877
|
self.epochs += ckpt["epoch"] # finetune additional epochs
|
|
779
|
-
self.
|
|
878
|
+
self._load_checkpoint_state(ckpt)
|
|
780
879
|
self.start_epoch = start_epoch
|
|
781
880
|
if start_epoch > (self.epochs - self.args.close_mosaic):
|
|
782
881
|
self._close_dataloader_mosaic()
|
|
@@ -790,18 +889,16 @@ class BaseTrainer:
|
|
|
790
889
|
self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
|
|
791
890
|
|
|
792
891
|
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
|
793
|
-
"""
|
|
794
|
-
Construct an optimizer for the given model.
|
|
892
|
+
"""Construct an optimizer for the given model.
|
|
795
893
|
|
|
796
894
|
Args:
|
|
797
895
|
model (torch.nn.Module): The model for which to build an optimizer.
|
|
798
|
-
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
|
|
799
|
-
|
|
800
|
-
lr (float, optional): The learning rate for the optimizer.
|
|
801
|
-
momentum (float, optional): The momentum factor for the optimizer.
|
|
802
|
-
decay (float, optional): The weight decay for the optimizer.
|
|
803
|
-
iterations (float, optional): The number of iterations, which determines the optimizer if
|
|
804
|
-
name is 'auto'. Default: 1e5.
|
|
896
|
+
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
|
|
897
|
+
number of iterations.
|
|
898
|
+
lr (float, optional): The learning rate for the optimizer.
|
|
899
|
+
momentum (float, optional): The momentum factor for the optimizer.
|
|
900
|
+
decay (float, optional): The weight decay for the optimizer.
|
|
901
|
+
iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
|
|
805
902
|
|
|
806
903
|
Returns:
|
|
807
904
|
(torch.optim.Optimizer): The constructed optimizer.
|