dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
```diff
@@ -3,9 +3,11 @@
 Train a model on a dataset.
 
 Usage:
-    $ yolo mode=train model=
+    $ yolo mode=train model=yolo26n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """
 
+from __future__ import annotations
+
 import gc
 import math
 import os
@@ -14,6 +16,7 @@ import time
 import warnings
 from copy import copy, deepcopy
 from datetime import datetime, timedelta
+from functools import partial
 from pathlib import Path
 
 import numpy as np
@@ -25,6 +28,7 @@ from ultralytics import __version__
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.tasks import load_checkpoint
+from ultralytics.optim import MuSGD
 from ultralytics.utils import (
     DEFAULT_CFG,
     GIT,
@@ -61,8 +65,7 @@ from ultralytics.utils.torch_utils import (
 
 
 class BaseTrainer:
-    """
-    A base class for creating trainers.
+    """A base class for creating trainers.
 
     This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
     and various training utilities. It supports both single-GPU and multi-GPU distributed training.
```
```diff
@@ -112,8 +115,7 @@ class BaseTrainer:
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the BaseTrainer class.
+        """Initialize the BaseTrainer class.
 
         Args:
             cfg (str, optional): Path to a configuration file.
@@ -138,7 +140,12 @@ class BaseTrainer:
         if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
-            YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+            # Save run args, serializing augmentations as reprs for resume compatibility
+            args_dict = vars(self.args).copy()
+            if args_dict.get("augmentations") is not None:
+                # Serialize Albumentations transforms as their repr strings for checkpoint compatibility
+                args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
+            YAML.save(self.save_dir / "args.yaml", args_dict)  # save run args
         self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
         self.save_period = self.args.save_period
 
```
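The args-saving change above works around the fact that live transform objects are not YAML-serializable: they are replaced by their `repr()` strings before `args.yaml` is written. A minimal standalone sketch of the idea (the `Blur` class here is a hypothetical stand-in for an Albumentations transform, not its real API):

```python
# Sketch: store non-YAML-safe objects as repr strings in a run config.
class Blur:  # hypothetical stand-in for an Albumentations transform
    def __init__(self, p: float = 0.5):
        self.p = p

    def __repr__(self) -> str:
        return f"Blur(p={self.p})"


args = {"epochs": 100, "augmentations": [Blur(p=0.1)]}
serializable = args.copy()
if serializable.get("augmentations") is not None:
    serializable["augmentations"] = [repr(t) for t in serializable["augmentations"]]
print(serializable)  # {'epochs': 100, 'augmentations': ['Blur(p=0.1)']}
```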
```diff
@@ -152,8 +159,29 @@ class BaseTrainer:
         if self.device.type in {"cpu", "mps"}:
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
 
+        # Callbacks - initialize early so on_pretrain_routine_start can capture original args.data
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device=None or device=''
+            world_size = 0
+
+        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
+        self.world_size = world_size
+        # Run on_pretrain_routine_start before get_dataset() to capture original args.data (e.g., ul:// URIs)
+        if RANK in {-1, 0} and not self.ddp:
+            callbacks.add_integration_callbacks(self)
+            self.run_callbacks("on_pretrain_routine_start")
+
         # Model and Dataset
-        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e.
+        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo26n -> yolo26n.pt
         with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
             self.data = self.get_dataset()
 
```
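The device-parsing block moved into `__init__` above maps every accepted `device` spelling to a process count, which then decides whether a DDP subprocess launch is needed. A condensed sketch of the same mapping:

```python
import os

import torch


def world_size_from_device(device) -> int:
    """Map a device spec ('0,1', [0, 1], 'cpu', None, ...) to a world size."""
    if isinstance(device, str) and len(device):  # '0' or '0,1,2,3'
        return len(device.split(","))
    if isinstance(device, (tuple, list)):  # [0, 1, 2, 3]
        return len(device)
    if device in {"cpu", "mps"}:
        return 0
    return 1 if torch.cuda.is_available() else 0  # None or '' -> one GPU if present


world_size = world_size_from_device("0,1")
ddp = world_size > 1 and "LOCAL_RANK" not in os.environ  # parent process must spawn workers
print(world_size, ddp)
```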
```diff
@@ -175,28 +203,6 @@ class BaseTrainer:
         self.plot_idx = [0, 1, 2]
         self.nan_recovery_attempts = 0
 
-        # Callbacks
-        self.callbacks = _callbacks or callbacks.get_default_callbacks()
-
-        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
-            world_size = len(self.args.device.split(","))
-        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
-            world_size = len(self.args.device)
-        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
-            world_size = 0
-        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
-            world_size = 1  # default to device 0
-        else:  # i.e. device=None or device=''
-            world_size = 0
-
-        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
-        self.world_size = world_size
-        # Run subprocess if DDP training, else train normally
-        if RANK in {-1, 0} and not self.ddp:
-            callbacks.add_integration_callbacks(self)
-            # Start console logging immediately at trainer initialization
-            self.run_callbacks("on_pretrain_routine_start")
-
     def add_callback(self, event: str, callback):
         """Append the given callback to the event's callback list."""
         self.callbacks[event].append(callback)
@@ -318,18 +324,18 @@ class BaseTrainer:
         self.train_loader = self.get_dataloader(
             self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
         )
+        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+        self.test_loader = self.get_dataloader(
+            self.data.get("val") or self.data.get("test"),
+            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
+            rank=LOCAL_RANK,
+            mode="val",
+        )
+        self.validator = self.get_validator()
+        self.ema = ModelEMA(self.model)
         if RANK in {-1, 0}:
-            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
-            self.test_loader = self.get_dataloader(
-                self.data.get("val") or self.data.get("test"),
-                batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
-                rank=-1,
-                mode="val",
-            )
-            self.validator = self.get_validator()
             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
-            self.ema = ModelEMA(self.model)
         if self.args.plots:
             self.plot_training_labels()
 
@@ -403,10 +409,15 @@ class BaseTrainer:
                 if ni <= nw:
                     xi = [0, nw]  # x interp
                     self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
-                    for
+                    for x in self.optimizer.param_groups:
                         # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                         x["lr"] = np.interp(
-                            ni,
+                            ni,
+                            xi,
+                            [
+                                self.args.warmup_bias_lr if x.get("param_group") == "bias" else 0.0,
+                                x["initial_lr"] * self.lf(epoch),
+                            ],
                         )
                         if "momentum" in x:
                             x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
```
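The rewritten warmup now reads each group's start point from its `param_group` tag: bias groups start at `warmup_bias_lr` and fall toward the scheduled LR, every other group ramps up from 0. A self-contained sketch of the schedule, with a plain linear-decay lambda standing in for the trainer's `self.lf`:

```python
import numpy as np

nw = 100  # warmup iterations
initial_lr = 0.01  # the param group's initial_lr
warmup_bias_lr = 0.1
epoch = 0
lf = lambda e: 1 - 0.9 * e / 10  # stand-in for the trainer's LR lambda

for group, start in (("weight", 0.0), ("bias", warmup_bias_lr)):
    # Bias lr falls from warmup_bias_lr to lr0; all other lrs rise from 0.0 to lr0
    for ni in (0, 50, 100):  # integrated batch index
        lr = np.interp(ni, [0, nw], [start, initial_lr * lf(epoch)])
        print(f"{group:6s} ni={ni:3d} lr={lr:.4f}")
```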
```diff
@@ -460,17 +471,20 @@ class BaseTrainer:
 
                 self.run_callbacks("on_train_batch_end")
 
+            if hasattr(unwrap_model(self.model).criterion, "update"):
+                unwrap_model(self.model).criterion.update()
+
             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
 
             self.run_callbacks("on_train_epoch_end")
             if RANK in {-1, 0}:
-                final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
 
-                # Validation
-                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
-                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
-                    self.metrics, self.fitness = self.validate()
+            # Validation
+            final_epoch = epoch + 1 >= self.epochs
+            if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                self._clear_memory(threshold=0.5)  # prevent VRAM spike
+                self.metrics, self.fitness = self.validate()
 
             # NaN recovery
             if self._handle_nan_recovery(epoch):
```
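The `hasattr(..., "update")` guard added at the end of each epoch is a duck-typed hook: any criterion that exposes an `update()` method (for example, to anneal an internal loss weight) is stepped once per epoch, and criteria without one are left untouched. A minimal sketch of the pattern:

```python
class StaticLoss:
    """Criterion with no per-epoch state."""


class ScheduledLoss:
    """Criterion that anneals an internal weight once per epoch."""

    def __init__(self):
        self.weight = 1.0

    def update(self):
        self.weight *= 0.95


for criterion in (StaticLoss(), ScheduledLoss()):
    if hasattr(criterion, "update"):  # only step criteria that opt in
        criterion.update()
```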
```diff
@@ -510,11 +524,11 @@ class BaseTrainer:
                 break  # must break all DDP ranks
             epoch += 1
 
+        seconds = time.time() - self.train_time_start
+        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+        # Do final val with best.pt
+        self.final_eval()
         if RANK in {-1, 0}:
-            # Do final val with best.pt
-            seconds = time.time() - self.train_time_start
-            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
-            self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
@@ -545,7 +559,7 @@ class BaseTrainer:
         total = torch.cuda.get_device_properties(self.device).total_memory
         return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
 
-    def _clear_memory(self, threshold: float = None):
+    def _clear_memory(self, threshold: float | None = None):
         """Clear accelerator memory by calling garbage collector and emptying cache."""
         if threshold:
             assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
```
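`_clear_memory` now annotates its optional threshold as `float | None`, the PEP 604 syntax enabled on older Pythons by the new `from __future__ import annotations` import. A hedged sketch of what such a helper typically does (not the trainer's exact body): skip the relatively slow cache flush unless reserved memory crosses the threshold:

```python
from __future__ import annotations

import gc

import torch


def clear_memory(device: torch.device, threshold: float | None = None) -> None:
    """Collect garbage and empty the CUDA cache, optionally gated by memory usage."""
    if threshold is not None and device.type == "cuda":
        total = torch.cuda.get_device_properties(device).total_memory
        if torch.cuda.memory_reserved(device) / total < threshold:
            return  # usage below threshold: skip the flush
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```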
```diff
@@ -618,25 +632,26 @@ class BaseTrainer:
             (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
 
     def get_dataset(self):
-        """
-        Get train and validation datasets from data dictionary.
+        """Get train and validation datasets from data dictionary.
 
         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
         """
         try:
-            if self.args.task == "classify":
-                data = check_cls_dataset(self.args.data)
-            elif str(self.args.data).endswith(".ndjson"):
-                # Convert NDJSON to YOLO format
+            # Convert ul:// platform URIs and NDJSON files to local dataset format first
+            data_str = str(self.args.data)
+            if data_str.endswith(".ndjson") or (data_str.startswith("ul://") and "/datasets/" in data_str):
                 import asyncio
 
                 from ultralytics.data.converter import convert_ndjson_to_yolo
+                from ultralytics.utils.checks import check_file
 
-
-
-
-            elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
+                self.args.data = str(asyncio.run(convert_ndjson_to_yolo(check_file(self.args.data))))
+
+            # Task-specific dataset checking
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
                 "pose",
```
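`get_dataset` now normalizes its input in two stages: `.ndjson` exports and `ul://.../datasets/...` platform URIs are first converted to a local YOLO dataset via `convert_ndjson_to_yolo`, and only then do the task-specific checks run. The dispatch predicate on its own:

```python
def needs_conversion(data: str) -> bool:
    """True for NDJSON files and ul:// platform dataset URIs, per the diff above."""
    return data.endswith(".ndjson") or (data.startswith("ul://") and "/datasets/" in data)


# 'ul://acme/datasets/traffic' is an illustrative URI, not a real dataset
for spec in ("coco8.yaml", "export.ndjson", "ul://acme/datasets/traffic"):
    print(f"{spec}: {'convert first' if needs_conversion(spec) else 'check directly'}")
```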
```diff
@@ -654,8 +669,7 @@ class BaseTrainer:
             return data
 
     def setup_model(self):
-        """
-        Load, create, or download model for any task.
+        """Load, create, or download model for any task.
 
         Returns:
             (dict): Optional checkpoint to resume training from.
@@ -688,14 +702,19 @@ class BaseTrainer:
         return batch
 
     def validate(self):
-        """
-        Run validation on val set using self.validator.
+        """Run validation on val set using self.validator.
 
         Returns:
             metrics (dict): Dictionary of validation metrics.
             fitness (float): Fitness score for the validation.
         """
+        if self.ema and self.world_size > 1:
+            # Sync EMA buffers from rank 0 to all ranks
+            for buffer in self.ema.ema.buffers():
+                dist.broadcast(buffer, src=0)
         metrics = self.validator(self)
+        if metrics is None:
+            return None, None
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
```
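Broadcasting the EMA buffers before validation keeps state such as BatchNorm running statistics identical across ranks, so together with the `rank=LOCAL_RANK` val loader change above, every rank now validates the same model. A minimal sketch of the collective, assuming an initialized process group:

```python
import torch.distributed as dist
import torch.nn as nn


def sync_buffers_from_rank0(model: nn.Module) -> None:
    """Copy rank 0's buffers (e.g. BatchNorm running stats) to all other ranks."""
    if dist.is_available() and dist.is_initialized():
        for buffer in model.buffers():
            dist.broadcast(buffer, src=0)  # in-place on every non-source rank
```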
```diff
@@ -706,11 +725,11 @@ class BaseTrainer:
             raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
-        """
+        """Raise NotImplementedError (must be implemented by subclasses)."""
         raise NotImplementedError("get_validator function not implemented in trainer")
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """
+        """Raise NotImplementedError (must return a `torch.utils.data.DataLoader` in subclasses)."""
         raise NotImplementedError("get_dataloader function not implemented in trainer")
 
     def build_dataset(self, img_path, mode="train", batch=None):
@@ -718,10 +737,9 @@ class BaseTrainer:
         raise NotImplementedError("build_dataset function not implemented in trainer")
 
     def label_loss_items(self, loss_items=None, prefix="train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
-
+        Notes:
             This is not needed for classification but necessary for segmentation & detection
         """
         return {"loss": loss_items} if loss_items is not None else ["loss"]
@@ -753,9 +771,9 @@ class BaseTrainer:
         n = len(metrics) + 2  # number of cols
         t = time.time() - self.train_time_start
         self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
-        s = "" if self.csv.exists() else (
+        s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
         with open(self.csv, "a", encoding="utf-8") as f:
-            f.write(s + ("%.6g," * n %
+            f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")
 
     def plot_metrics(self):
         """Plot metrics from a CSV file."""
```
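The rebuilt CSV-logging one-liners use repeated format specifiers to emit the header and each data row. Worked through for n = 4 columns:

```python
n = 4  # epoch, time + 2 metric columns
keys = ["metrics/mAP50", "val/loss"]
vals = [0.521, 1.234]

header = ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
row = ("%.6g," * n % (3, 128.5, *vals)).rstrip(",") + "\n"
print(header + row, end="")
# epoch,time,metrics/mAP50,val/loss
# 3,128.5,0.521,1.234
```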
```diff
@@ -768,20 +786,20 @@ class BaseTrainer:
 
     def final_eval(self):
         """Perform final evaluation and validation for object detection YOLO model."""
-
-
-        if
-
-
-
-
-
-
-
-
-
-
-
+        model = self.best if self.best.exists() else None
+        with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
+            if RANK in {-1, 0}:
+                ckpt = strip_optimizer(self.last) if self.last.exists() else {}
+                if model:
+                    # update best.pt train_metrics from last.pt
+                    strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
+        if model:
+            LOGGER.info(f"\nValidating {model}...")
+            self.validator.args.plots = self.args.plots
+            self.validator.args.compile = False  # disable final val compile as too slow
+            self.metrics = self.validator(model=model)
+            self.metrics.pop("fitness", None)
+            self.run_callbacks("on_fit_epoch_end")
 
     def check_resume(self, overrides):
         """Check if resume checkpoint exists and update arguments accordingly."""
```
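`final_eval` now wraps checkpoint stripping in `torch_distributed_zero_first`, a barrier context manager under which rank 0 does the work while the other ranks wait. A generic sketch of the pattern (not the exact Ultralytics helper):

```python
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def zero_first(local_rank: int):
    """Make non-zero ranks wait until rank 0 (or single-process -1) runs the body."""
    initialized = dist.is_available() and dist.is_initialized()
    if initialized and local_rank not in {-1, 0}:
        dist.barrier()  # non-zero ranks block here first
    yield
    if initialized and local_rank == 0:
        dist.barrier()  # rank 0 releases the waiting ranks when done
```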
```diff
@@ -804,10 +822,29 @@ class BaseTrainer:
                 "batch",
                 "device",
                 "close_mosaic",
+                "augmentations",
+                "save_period",
+                "workers",
+                "cache",
+                "patience",
+                "time",
+                "freeze",
+                "val",
+                "plots",
             ):  # allow arg updates to reduce memory or update device on resume
                 if k in overrides:
                     setattr(self.args, k, overrides[k])
 
+            # Handle augmentations parameter for resume: check if user provided custom augmentations
+            if ckpt_args.get("augmentations") is not None:
+                # Augmentations were saved in checkpoint as reprs but can't be restored automatically
+                LOGGER.warning(
+                    "Custom Albumentations transforms were used in the original training run but are not "
+                    "being restored. To preserve custom augmentations when resuming, you need to pass the "
+                    "'augmentations' parameter again to get expected results. Example: \n"
+                    f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
+                )
+
         except Exception as e:
             raise FileNotFoundError(
                 "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
```
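Because the checkpoint stores only transform reprs (see the args.yaml change above), a resumed run has to be handed the live transform objects again, exactly as the new warning says. A hedged usage example (the checkpoint path and the specific transform are illustrative):

```python
import albumentations as A  # assumes albumentations is installed

from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/last.pt")  # illustrative checkpoint path
model.train(resume=True, augmentations=[A.Blur(p=0.01)])  # re-supply live transforms
```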
```diff
@@ -887,23 +924,21 @@ class BaseTrainer:
             self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
 
     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
-        """
-        Construct an optimizer for the given model.
+        """Construct an optimizer for the given model.
 
         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
-            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
-                based on the number of iterations.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
+                number of iterations.
             lr (float, optional): The learning rate for the optimizer.
             momentum (float, optional): The momentum factor for the optimizer.
             decay (float, optional): The weight decay for the optimizer.
-            iterations (float, optional): The number of iterations, which determines the optimizer if
-                name is 'auto'.
+            iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
 
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
         """
-        g = [
+        g = [{}, {}, {}, {}]  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
         if name == "auto":
             LOGGER.info(
@@ -913,38 +948,62 @@ class BaseTrainer:
             )
             nc = self.data.get("nc", 10)  # number of classes
             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
-            name, lr, momentum = ("
+            name, lr, momentum = ("MuSGD", 0.01 if iterations > 10000 else lr_fit, 0.9)
             self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
 
-        for module_name, module in model.named_modules():
+        use_muon = name == "MuSGD"
+        for module_name, module in unwrap_model(model).named_modules():
             for param_name, param in module.named_parameters(recurse=False):
                 fullname = f"{module_name}.{param_name}" if module_name else param_name
-                if
-                    g[
+                if param.ndim >= 2 and use_muon:
+                    g[3][fullname] = param  # muon params
+                elif "bias" in fullname:  # bias (no decay)
+                    g[2][fullname] = param
                 elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
                     # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
-                    g[1]
+                    g[1][fullname] = param
                 else:  # weight (with decay)
-                    g[0]
+                    g[0][fullname] = param
+        if not use_muon:
+            g = [x.values() for x in g[:3]]  # convert to list of params
 
-        optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
+        optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "MuSGD", "auto"}
         name = {x.lower(): x for x in optimizers}.get(name.lower())
         if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
-
+            optim_args = dict(lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
         elif name == "RMSProp":
-
-        elif name == "SGD":
-
+            optim_args = dict(lr=lr, momentum=momentum)
+        elif name == "SGD" or name == "MuSGD":
+            optim_args = dict(lr=lr, momentum=momentum, nesterov=True)
         else:
             raise NotImplementedError(
                 f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
                 "Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
             )
 
-
-
+        num_params = [len(g[0]), len(g[1]), len(g[2])]  # number of param groups
+        g[2] = {"params": g[2], **optim_args, "param_group": "bias"}
+        g[0] = {"params": g[0], **optim_args, "weight_decay": decay, "param_group": "weight"}
+        g[1] = {"params": g[1], **optim_args, "weight_decay": 0.0, "param_group": "bn"}
+        muon, sgd = (0.1, 1.0) if iterations > 10000 else (0.5, 0.5)  # scale factor for MuSGD
+        if use_muon:
+            num_params[0] = len(g[3])  # update number of params
+            g[3] = {"params": g[3], **optim_args, "weight_decay": decay, "use_muon": True, "param_group": "muon"}
+            import re
+
+            # higher lr for certain parameters in MuSGD when funetuning
+            pattern = re.compile(r"(?=.*23)(?=.*cv3)|proto\.semseg|flow_model")
+            g_ = []  # new param groups
+            for x in g:
+                p = x.pop("params")
+                p1 = [v for k, v in p.items() if pattern.search(k)]
+                p2 = [v for k, v in p.items() if not pattern.search(k)]
+                g_.extend([{"params": p1, **x, "lr": lr * 3}, {"params": p2, **x}])
+            g = g_
+        optimizer = getattr(optim, name, partial(MuSGD, muon=muon, sgd=sgd))(params=g)
 
         LOGGER.info(
             f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
-            f"{
+            f"{num_params[1]} weight(decay=0.0), {num_params[0]} weight(decay={decay}), {num_params[2]} bias(decay=0.0)"
         )
         return optimizer
```
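The rewritten `build_optimizer` keeps four parameter groups keyed by name — decayed weights (`g[0]`), norm weights (`g[1]`), biases (`g[2]`), and, when MuSGD is selected, all ≥2-D parameters (`g[3]`) — so the names survive long enough for the fine-tuning regex split, and then resolves unknown optimizer names through `getattr(optim, name, partial(MuSGD, ...))`. A condensed, runnable sketch of that grouping-and-fallback wiring (the `MuSGD` class below is a hypothetical stand-in for `ultralytics.optim.MuSGD`):

```python
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim


class MuSGD(optim.SGD):  # hypothetical stand-in for ultralytics.optim.MuSGD
    def __init__(self, params, muon=0.5, sgd=0.5, lr=0.01, **kwargs):
        super().__init__(params, lr=lr, **kwargs)


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Flatten(), nn.LazyLinear(10))
model(torch.zeros(1, 3, 8, 8))  # materialize the lazy linear layer

name, use_muon = "MuSGD", True
g = [{}, {}, {}, {}]  # weights(decay), norm weights(no decay), biases(no decay), muon
for module_name, module in model.named_modules():
    for param_name, param in module.named_parameters(recurse=False):
        fullname = f"{module_name}.{param_name}" if module_name else param_name
        if param.ndim >= 2 and use_muon:
            g[3][fullname] = param  # matrices are routed to the Muon group
        elif "bias" in fullname:
            g[2][fullname] = param
        elif isinstance(module, nn.BatchNorm2d):
            g[1][fullname] = param
        else:
            g[0][fullname] = param

decays = (1e-5, 0.0, 0.0, 1e-5)
groups = [{"params": list(d.values()), "weight_decay": wd} for d, wd in zip(g, decays) if d]
# getattr falls back to the MuSGD partial when `name` is not a torch.optim attribute
optimizer = getattr(optim, name, partial(MuSGD, muon=0.5, sgd=0.5))(params=groups)
print(type(optimizer).__name__, [len(pg["params"]) for pg in optimizer.param_groups])
```

In the real implementation the regex split then gives a 3x learning rate to name-matched parameters (`cv3`/`23` head layers, `proto.semseg`, `flow_model`) before the optimizer is built, which is why the groups stay name-keyed until that point.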