dgenerate_ultralytics_headless-8.3.196-py3-none-any.whl → dgenerate_ultralytics_headless-8.3.248-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
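The trainer diff shown below moves the DDP world-size calculation into `BaseTrainer.__init__`, saves the AMP scaler state into checkpoints, adds NaN/Inf loss recovery from `last.pt`, and threads a `compile` training argument through to `attempt_compile()`. A minimal sketch of how the new options might surface in the Python API, assuming the `compile` and `augmentations` train arguments behave as the diff suggests:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")

# Sketch: 'compile' is read by the trainer as self.args.compile; when truthy, the forward
# pass and loss computation are decoupled and the model is wrapped by attempt_compile().
model.train(data="coco8.yaml", epochs=3, compile=True)

# Resuming a run that used custom Albumentations transforms: the checkpoint stores only
# their repr() strings, so (per the new warning in check_resume) they must be passed again.
# model.train(resume=True, augmentations=[...])  # hypothetical: re-supply the original transforms
```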
ultralytics/engine/trainer.py
CHANGED
@@ -6,6 +6,8 @@ Usage:
     $ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """

+from __future__ import annotations
+
 import gc
 import math
 import os
@@ -42,6 +44,7 @@ from ultralytics.utils.autobatch import check_train_batch_size
 from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
+from ultralytics.utils.plotting import plot_results
 from ultralytics.utils.torch_utils import (
     TORCH_2_4,
     EarlyStopping,
@@ -60,8 +63,7 @@ from ultralytics.utils.torch_utils import (


 class BaseTrainer:
-    """
-    A base class for creating trainers.
+    """A base class for creating trainers.

     This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
     and various training utilities. It supports both single-GPU and multi-GPU distributed training.
@@ -111,17 +113,17 @@ class BaseTrainer:
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the BaseTrainer class.
+        """Initialize the BaseTrainer class.

         Args:
             cfg (str, optional): Path to a configuration file.
             overrides (dict, optional): Configuration overrides.
             _callbacks (list, optional): List of callback functions.
         """
+        self.hub_session = overrides.pop("session", None)  # HUB
         self.args = get_cfg(cfg, overrides)
         self.check_resume(overrides)
-        self.device = select_device(self.args.device, self.args.batch)
+        self.device = select_device(self.args.device)
         # Update "-1" devices so post-training val does not repeat search
         self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
         self.validator = None
@@ -136,7 +138,12 @@ class BaseTrainer:
         if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
-            YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+            # Save run args, serializing augmentations as reprs for resume compatibility
+            args_dict = vars(self.args).copy()
+            if args_dict.get("augmentations") is not None:
+                # Serialize Albumentations transforms as their repr strings for checkpoint compatibility
+                args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
+            YAML.save(self.save_dir / "args.yaml", args_dict)  # save run args
         self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
         self.save_period = self.args.save_period

@@ -168,14 +175,29 @@ class BaseTrainer:
         self.tloss = None
         self.loss_names = ["Loss"]
         self.csv = self.save_dir / "results.csv"
+        if self.csv.exists() and not self.args.resume:
+            self.csv.unlink()
         self.plot_idx = [0, 1, 2]
-
-        # HUB
-        self.hub_session = None
+        self.nan_recovery_attempts = 0

         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
-        if RANK in {-1, 0}:
+
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device=None or device=''
+            world_size = 0
+
+        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
+        self.world_size = world_size
+        # Run subprocess if DDP training, else train normally
+        if RANK in {-1, 0} and not self.ddp:
             callbacks.add_integration_callbacks(self)
             # Start console logging immediately at trainer initialization
             self.run_callbacks("on_pretrain_routine_start")
@@ -195,31 +217,20 @@ class BaseTrainer:

     def train(self):
         """Allow device='', device=None on Multi-GPU systems to default to device=0."""
-        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
-            world_size = len(self.args.device.split(","))
-        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
-            world_size = len(self.args.device)
-        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
-            world_size = 0
-        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
-            world_size = 1  # default to device 0
-        else:  # i.e. device=None or device=''
-            world_size = 0
-
         # Run subprocess if DDP training, else train normally
-        if world_size > 1 and "LOCAL_RANK" not in os.environ:
+        if self.ddp:
             # Argument checks
             if self.args.rect:
                 LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                 self.args.rect = False
             if self.args.batch < 1.0:
-                LOGGER.warning(
-                    "'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting default 'batch=16'"
+                raise ValueError(
+                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
+                    f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
                 )
-                self.args.batch = 16

             # Command
-            cmd, file = generate_ddp_command(world_size, self)
+            cmd, file = generate_ddp_command(self)
             try:
                 LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
                 subprocess.run(cmd, check=True)
@@ -229,7 +240,7 @@ class BaseTrainer:
                 ddp_cleanup(self, str(file))

         else:
-            self._do_train(world_size)
+            self._do_train()

     def _setup_scheduler(self):
         """Initialize training learning rate scheduler."""
@@ -239,32 +250,26 @@ class BaseTrainer:
             self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)

-    def _setup_ddp(self, world_size):
+    def _setup_ddp(self):
         """Initialize and set the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
         self.device = torch.device("cuda", RANK)
-        # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
             backend="nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
-            world_size=world_size,
+            world_size=self.world_size,
         )

-    def _setup_train(self, world_size):
+    def _setup_train(self):
         """Build dataloaders and optimizer on correct rank process."""
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()

-        # Initialize loss criterion before compilation for torch.compile compatibility
-        if hasattr(self.model, "init_criterion"):
-            self.model.criterion = self.model.init_criterion()
-
         # Compile model
-
-        self.model = attempt_compile(self.model, device=self.device)
+        self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)

         # Freeze layers
         freeze_list = (
@@ -295,13 +300,13 @@ class BaseTrainer:
             callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
             self.amp = torch.tensor(check_amp(self.model), device=self.device)
             callbacks.default_callbacks = callbacks_backup  # restore callbacks
-        if RANK > -1 and world_size > 1:  # DDP
+        if RANK > -1 and self.world_size > 1:  # DDP
             dist.broadcast(self.amp.int(), src=0)  # broadcast from rank 0 to all other ranks; gloo errors with boolean
         self.amp = bool(self.amp)  # as boolean
         self.scaler = (
             torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
         )
-        if world_size > 1:
+        if self.world_size > 1:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)

         # Check imgsz
@@ -314,22 +319,22 @@ class BaseTrainer:
             self.args.batch = self.batch_size = self.auto_batch()

         # Dataloaders
-        batch_size = self.batch_size // max(world_size, 1)
+        batch_size = self.batch_size // max(self.world_size, 1)
         self.train_loader = self.get_dataloader(
             self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
         )
+        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+        self.test_loader = self.get_dataloader(
+            self.data.get("val") or self.data.get("test"),
+            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
+            rank=LOCAL_RANK,
+            mode="val",
+        )
+        self.validator = self.get_validator()
+        self.ema = ModelEMA(self.model)
         if RANK in {-1, 0}:
-            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
-            self.test_loader = self.get_dataloader(
-                self.data.get("val") or self.data.get("test"),
-                batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
-                rank=-1,
-                mode="val",
-            )
-            self.validator = self.get_validator()
             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
-            self.ema = ModelEMA(self.model)
             if self.args.plots:
                 self.plot_training_labels()

@@ -352,11 +357,11 @@ class BaseTrainer:
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
         self.run_callbacks("on_pretrain_routine_end")

-    def _do_train(self, world_size=1):
+    def _do_train(self):
         """Train the model with the specified world size."""
-        if world_size > 1:
-            self._setup_ddp(world_size)
-        self._setup_train(world_size)
+        if self.world_size > 1:
+            self._setup_ddp()
+        self._setup_train()

         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
@@ -367,7 +372,7 @@ class BaseTrainer:
         self.run_callbacks("on_train_start")
         LOGGER.info(
             f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
-            f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
+            f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
             f"Logging results to {colorstr('bold', self.save_dir)}\n"
             f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
         )
@@ -414,19 +419,19 @@ class BaseTrainer:
                 # Forward
                 with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    preds = self.model(batch["img"])
-                    loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
+                    if self.args.compile:
+                        # Decouple inference and loss calculations for improved compile performance
+                        preds = self.model(batch["img"])
+                        loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
+                    else:
+                        loss, self.loss_items = self.model(batch)
                     self.loss = loss.sum()
                     if RANK != -1:
-                        self.loss *= world_size
-                    self.tloss = (
-                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
-                    )
+                        self.loss *= self.world_size
+                    self.tloss = self.loss_items if self.tloss is None else (self.tloss * i + self.loss_items) / (i + 1)

                 # Backward
                 self.scaler.scale(self.loss).backward()
-
-                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                 if ni - last_opt_step >= self.accumulate:
                     self.optimizer_step()
                     last_opt_step = ni
@@ -456,21 +461,28 @@ class BaseTrainer:
                     )
                     self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
-                        batch = {**batch, **metadata}
                         self.plot_training_samples(batch, ni)

                 self.run_callbacks("on_train_batch_end")

             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+
             self.run_callbacks("on_train_epoch_end")
             if RANK in {-1, 0}:
-                final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

-                # Validation
-                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
-                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
-                    self.metrics, self.fitness = self.validate()
+            # Validation
+            final_epoch = epoch + 1 >= self.epochs
+            if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                self._clear_memory(threshold=0.5)  # prevent VRAM spike
+                self.metrics, self.fitness = self.validate()
+
+            # NaN recovery
+            if self._handle_nan_recovery(epoch):
+                continue
+
+            self.nan_recovery_attempts = 0
+            if RANK in {-1, 0}:
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                 self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
                 if self.args.time:
@@ -503,11 +515,11 @@ class BaseTrainer:
                 break  # must break all DDP ranks
             epoch += 1

+        seconds = time.time() - self.train_time_start
+        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+        # Do final val with best.pt
+        self.final_eval()
         if RANK in {-1, 0}:
-            # Do final val with best.pt
-            seconds = time.time() - self.train_time_start
-            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
-            self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
@@ -538,7 +550,7 @@ class BaseTrainer:
                 total = torch.cuda.get_device_properties(self.device).total_memory
         return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)

-    def _clear_memory(self, threshold: float = None):
+    def _clear_memory(self, threshold: float | None = None):
         """Clear accelerator memory by calling garbage collector and emptying cache."""
         if threshold:
             assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
@@ -556,7 +568,10 @@ class BaseTrainer:
         """Read results.csv into a dictionary using polars."""
         import polars as pl  # scope for faster 'import ultralytics'

-        return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
+        try:
+            return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
+        except Exception:
+            return {}

     def _model_train(self):
         """Set model in training mode."""
@@ -580,6 +595,7 @@ class BaseTrainer:
             "ema": deepcopy(unwrap_model(self.ema.ema)).half(),
             "updates": self.ema.updates,
             "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+            "scaler": self.scaler.state_dict(),
             "train_args": vars(self.args),  # save as dict
             "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
             "train_results": self.read_results_csv(),
@@ -599,6 +615,7 @@ class BaseTrainer:
         serialized_ckpt = buffer.getvalue()  # get the serialized content to save

         # Save checkpoints
+        self.wdir.mkdir(parents=True, exist_ok=True)  # ensure weights directory exists
         self.last.write_bytes(serialized_ckpt)  # save last.pt
         if self.best_fitness == self.fitness:
             self.best.write_bytes(serialized_ckpt)  # save best.pt
@@ -606,8 +623,7 @@ class BaseTrainer:
             (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'

     def get_dataset(self):
-        """
-        Get train and validation datasets from data dictionary.
+        """Get train and validation datasets from data dictionary.

         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
@@ -615,7 +631,7 @@ class BaseTrainer:
         try:
             if self.args.task == "classify":
                 data = check_cls_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] == "ndjson":
+            elif str(self.args.data).rsplit(".", 1)[-1] == "ndjson":
                 # Convert NDJSON to YOLO format
                 import asyncio

@@ -624,7 +640,7 @@ class BaseTrainer:
                 yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
                 self.args.data = str(yaml_path)
                 data = check_det_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
+            elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
                 "pose",
@@ -642,8 +658,7 @@ class BaseTrainer:
         return data

     def setup_model(self):
-        """
-        Load, create, or download model for any task.
+        """Load, create, or download model for any task.

         Returns:
             (dict): Optional checkpoint to resume training from.
@@ -664,7 +679,7 @@ class BaseTrainer:
     def optimizer_step(self):
         """Perform a single step of the training optimizer with gradient clipping and EMA update."""
         self.scaler.unscale_(self.optimizer)  # unscale gradients
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self.optimizer.zero_grad()
@@ -676,14 +691,19 @@ class BaseTrainer:
         return batch

     def validate(self):
-        """
-        Run validation on val set using self.validator.
+        """Run validation on val set using self.validator.

         Returns:
             metrics (dict): Dictionary of validation metrics.
             fitness (float): Fitness score for the validation.
         """
+        if self.ema and self.world_size > 1:
+            # Sync EMA buffers from rank 0 to all ranks
+            for buffer in self.ema.ema.buffers():
+                dist.broadcast(buffer, src=0)
         metrics = self.validator(self)
+        if metrics is None:
+            return None, None
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
@@ -694,11 +714,11 @@ class BaseTrainer:
         raise NotImplementedError("This task trainer doesn't support loading cfg files")

     def get_validator(self):
-        """
+        """Raise NotImplementedError (must be implemented by subclasses)."""
         raise NotImplementedError("get_validator function not implemented in trainer")

     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """
+        """Raise NotImplementedError (must return a `torch.utils.data.DataLoader` in subclasses)."""
         raise NotImplementedError("get_dataloader function not implemented in trainer")

     def build_dataset(self, img_path, mode="train", batch=None):
@@ -706,10 +726,9 @@ class BaseTrainer:
         raise NotImplementedError("build_dataset function not implemented in trainer")

     def label_loss_items(self, loss_items=None, prefix="train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.

-
+        Notes:
             This is not needed for classification but necessary for segmentation & detection
         """
         return {"loss": loss_items} if loss_items is not None else ["loss"]
@@ -739,14 +758,15 @@ class BaseTrainer:
         """Save training metrics to a CSV file."""
         keys, vals = list(metrics.keys()), list(metrics.values())
         n = len(metrics) + 2  # number of cols
-        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
         t = time.time() - self.train_time_start
+        self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
+        s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
         with open(self.csv, "a", encoding="utf-8") as f:
-            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
+            f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")

     def plot_metrics(self):
-        """Plot and display metrics visually."""
-        pass
+        """Plot metrics from a CSV file."""
+        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png

     def on_plot(self, name, data=None):
         """Register plots (e.g. to be consumed in callbacks)."""
@@ -755,20 +775,20 @@ class BaseTrainer:

     def final_eval(self):
         """Perform final evaluation and validation for object detection YOLO model."""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        model = self.best if self.best.exists() else None
+        with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
+            if RANK in {-1, 0}:
+                ckpt = strip_optimizer(self.last) if self.last.exists() else {}
+                if model:
+                    # update best.pt train_metrics from last.pt
+                    strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
+        if model:
+            LOGGER.info(f"\nValidating {model}...")
+            self.validator.args.plots = self.args.plots
+            self.validator.args.compile = False  # disable final val compile as too slow
+            self.metrics = self.validator(model=model)
+            self.metrics.pop("fitness", None)
+            self.run_callbacks("on_fit_epoch_end")

     def check_resume(self, overrides):
         """Check if resume checkpoint exists and update arguments accordingly."""
@@ -791,10 +811,29 @@ class BaseTrainer:
                     "batch",
                     "device",
                     "close_mosaic",
+                    "augmentations",
+                    "save_period",
+                    "workers",
+                    "cache",
+                    "patience",
+                    "time",
+                    "freeze",
+                    "val",
+                    "plots",
                 ):  # allow arg updates to reduce memory or update device on resume
                     if k in overrides:
                         setattr(self.args, k, overrides[k])

+            # Handle augmentations parameter for resume: check if user provided custom augmentations
+            if ckpt_args.get("augmentations") is not None:
+                # Augmentations were saved in checkpoint as reprs but can't be restored automatically
+                LOGGER.warning(
+                    "Custom Albumentations transforms were used in the original training run but are not "
+                    "being restored. To preserve custom augmentations when resuming, you need to pass the "
+                    "'augmentations' parameter again to get expected results. Example: \n"
+                    f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
+                )
+
         except Exception as e:
             raise FileNotFoundError(
                 "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
@@ -802,18 +841,54 @@ class BaseTrainer:
             ) from e
         self.resume = resume

+    def _load_checkpoint_state(self, ckpt):
+        """Load optimizer, scaler, EMA, and best_fitness from checkpoint."""
+        if ckpt.get("optimizer") is not None:
+            self.optimizer.load_state_dict(ckpt["optimizer"])
+        if ckpt.get("scaler") is not None:
+            self.scaler.load_state_dict(ckpt["scaler"])
+        if self.ema and ckpt.get("ema"):
+            self.ema = ModelEMA(self.model)  # validation with EMA creates inference tensors that can't be updated
+            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
+            self.ema.updates = ckpt["updates"]
+        self.best_fitness = ckpt.get("best_fitness", 0.0)
+
+    def _handle_nan_recovery(self, epoch):
+        """Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint."""
+        loss_nan = self.loss is not None and not self.loss.isfinite()
+        fitness_nan = self.fitness is not None and not np.isfinite(self.fitness)
+        fitness_collapse = self.best_fitness and self.best_fitness > 0 and self.fitness == 0
+        corrupted = RANK in {-1, 0} and loss_nan and (fitness_nan or fitness_collapse)
+        reason = "Loss NaN/Inf" if loss_nan else "Fitness NaN/Inf" if fitness_nan else "Fitness collapse"
+        if RANK != -1:  # DDP: broadcast to all ranks
+            broadcast_list = [corrupted if RANK == 0 else None]
+            dist.broadcast_object_list(broadcast_list, 0)
+            corrupted = broadcast_list[0]
+        if not corrupted:
+            return False
+        if epoch == self.start_epoch or not self.last.exists():
+            LOGGER.warning(f"{reason} detected but can not recover from last.pt...")
+            return False  # Cannot recover on first epoch, let training continue
+        self.nan_recovery_attempts += 1
+        if self.nan_recovery_attempts > 3:
+            raise RuntimeError(f"Training failed: NaN persisted for {self.nan_recovery_attempts} epochs")
+        LOGGER.warning(f"{reason} detected (attempt {self.nan_recovery_attempts}/3), recovering from last.pt...")
+        self._model_train()  # set model to train mode before loading checkpoint to avoid inference tensor errors
+        _, ckpt = load_checkpoint(self.last)
+        ema_state = ckpt["ema"].float().state_dict()
+        if not all(torch.isfinite(v).all() for v in ema_state.values() if isinstance(v, torch.Tensor)):
+            raise RuntimeError(f"Checkpoint {self.last} is corrupted with NaN/Inf weights")
+        unwrap_model(self.model).load_state_dict(ema_state)  # Load EMA weights into model
+        self._load_checkpoint_state(ckpt)  # Load optimizer/scaler/EMA/best_fitness
+        del ckpt, ema_state
+        self.scheduler.last_epoch = epoch - 1
+        return True
+
     def resume_training(self, ckpt):
         """Resume YOLO training from given epoch and best fitness."""
         if ckpt is None or not self.resume:
             return
-        best_fitness = 0.0
         start_epoch = ckpt.get("epoch", -1) + 1
-        if ckpt.get("optimizer", None) is not None:
-            self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
-            best_fitness = ckpt["best_fitness"]
-        if self.ema and ckpt.get("ema"):
-            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
-            self.ema.updates = ckpt["updates"]
         assert start_epoch > 0, (
             f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
             f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
@@ -824,7 +899,7 @@ class BaseTrainer:
                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
             )
             self.epochs += ckpt["epoch"]  # finetune additional epochs
-        self.best_fitness = best_fitness
+        self._load_checkpoint_state(ckpt)
         self.start_epoch = start_epoch
         if start_epoch > (self.epochs - self.args.close_mosaic):
             self._close_dataloader_mosaic()
@@ -838,18 +913,16 @@ class BaseTrainer:
             self.train_loader.dataset.close_mosaic(hyp=copy(self.args))

     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
-        """
-        Construct an optimizer for the given model.
+        """Construct an optimizer for the given model.

         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
-            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
-                based on the number of iterations.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
+                number of iterations.
             lr (float, optional): The learning rate for the optimizer.
             momentum (float, optional): The momentum factor for the optimizer.
             decay (float, optional): The weight decay for the optimizer.
-            iterations (float, optional): The number of iterations, which determines the optimizer if
-                name is 'auto'.
+            iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.

         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
|