dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
@@ -3,9 +3,11 @@
 Train a model on a dataset.
 
 Usage:
-    $ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
+    $ yolo mode=train model=yolo26n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """
 
+from __future__ import annotations
+
 import gc
 import math
 import os
@@ -14,6 +16,7 @@ import time
 import warnings
 from copy import copy, deepcopy
 from datetime import datetime, timedelta
+from functools import partial
 from pathlib import Path
 
 import numpy as np
@@ -25,6 +28,7 @@ from ultralytics import __version__
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.tasks import load_checkpoint
+from ultralytics.optim import MuSGD
 from ultralytics.utils import (
     DEFAULT_CFG,
     GIT,
@@ -61,8 +65,7 @@ from ultralytics.utils.torch_utils import (
 
 
 class BaseTrainer:
-    """
-    A base class for creating trainers.
+    """A base class for creating trainers.
 
     This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
     and various training utilities. It supports both single-GPU and multi-GPU distributed training.
@@ -112,8 +115,7 @@ class BaseTrainer:
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the BaseTrainer class.
+        """Initialize the BaseTrainer class.
 
         Args:
             cfg (str, optional): Path to a configuration file.
@@ -138,7 +140,12 @@
         if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
-            YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+            # Save run args, serializing augmentations as reprs for resume compatibility
+            args_dict = vars(self.args).copy()
+            if args_dict.get("augmentations") is not None:
+                # Serialize Albumentations transforms as their repr strings for checkpoint compatibility
+                args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
+            YAML.save(self.save_dir / "args.yaml", args_dict)  # save run args
         self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
         self.save_period = self.args.save_period
 
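Note: the repr round-trip above exists because args.yaml can only hold plain YAML types, so live transform objects (e.g. Albumentations instances) would fail to serialize. A minimal sketch of the failure mode and the workaround, using PyYAML directly rather than the package's own YAML helper, with a hypothetical stand-in transform class:

```python
import yaml

class Blur:
    """Hypothetical stand-in for an Albumentations transform object."""
    def __repr__(self):
        return "Blur(p=0.01)"

args = {"epochs": 100, "augmentations": [Blur()]}
try:
    yaml.safe_dump(args)  # safe_dump refuses arbitrary Python objects
except yaml.representer.RepresenterError as e:
    print(f"Cannot serialize live transforms: {e}")

args["augmentations"] = [repr(t) for t in args["augmentations"]]  # same trick as the diff
print(yaml.safe_dump(args), end="")  # augmentations are now plain strings
```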
@@ -152,8 +159,29 @@
         if self.device.type in {"cpu", "mps"}:
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
 
+        # Callbacks - initialize early so on_pretrain_routine_start can capture original args.data
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device=None or device=''
+            world_size = 0
+
+        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
+        self.world_size = world_size
+        # Run on_pretrain_routine_start before get_dataset() to capture original args.data (e.g., ul:// URIs)
+        if RANK in {-1, 0} and not self.ddp:
+            callbacks.add_integration_callbacks(self)
+            self.run_callbacks("on_pretrain_routine_start")
+
         # Model and Dataset
-        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo11n -> yolo11n.pt
+        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo26n -> yolo26n.pt
         with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
             self.data = self.get_dataset()
 
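Note: the device parsing above is pure string/collection logic, now run early enough for the pretrain callbacks to see it. A standalone sketch follows (hypothetical helper; CUDA availability is passed in so it runs anywhere). The cpu/mps case is checked first here for clarity, since in the original branch order a literal 'cpu' string would be consumed by the string branch:

```python
def infer_world_size(device, cuda_available: bool) -> int:
    """Sketch of the trainer's device -> world_size mapping (cpu/mps checked first for clarity)."""
    if device in {"cpu", "mps"}:                 # device='cpu' or 'mps' -> no DDP world
        return 0
    if isinstance(device, str) and len(device):  # device='0' or device='0,1,2,3'
        return len(device.split(","))
    if isinstance(device, (tuple, list)):        # device=[0, 1, 2, 3]
        return len(device)
    return 1 if cuda_available else 0            # device=None or '' -> default device 0 if CUDA exists

assert infer_world_size("0,1,2,3", cuda_available=True) == 4
assert infer_world_size([0, 1], cuda_available=True) == 2
assert infer_world_size("cpu", cuda_available=True) == 0
assert infer_world_size("", cuda_available=False) == 0
```

`self.ddp` then flags the parent process that must spawn the DDP subprocess: more than one device requested and no `LOCAL_RANK` in the environment yet.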
@@ -175,28 +203,6 @@
         self.plot_idx = [0, 1, 2]
         self.nan_recovery_attempts = 0
 
-        # Callbacks
-        self.callbacks = _callbacks or callbacks.get_default_callbacks()
-
-        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
-            world_size = len(self.args.device.split(","))
-        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
-            world_size = len(self.args.device)
-        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
-            world_size = 0
-        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
-            world_size = 1  # default to device 0
-        else:  # i.e. device=None or device=''
-            world_size = 0
-
-        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
-        self.world_size = world_size
-        # Run subprocess if DDP training, else train normally
-        if RANK in {-1, 0} and not self.ddp:
-            callbacks.add_integration_callbacks(self)
-            # Start console logging immediately at trainer initialization
-            self.run_callbacks("on_pretrain_routine_start")
-
     def add_callback(self, event: str, callback):
         """Append the given callback to the event's callback list."""
         self.callbacks[event].append(callback)
@@ -318,18 +324,18 @@
         self.train_loader = self.get_dataloader(
             self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
         )
+        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+        self.test_loader = self.get_dataloader(
+            self.data.get("val") or self.data.get("test"),
+            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
+            rank=LOCAL_RANK,
+            mode="val",
+        )
+        self.validator = self.get_validator()
+        self.ema = ModelEMA(self.model)
         if RANK in {-1, 0}:
-            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
-            self.test_loader = self.get_dataloader(
-                self.data.get("val") or self.data.get("test"),
-                batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
-                rank=-1,
-                mode="val",
-            )
-            self.validator = self.get_validator()
             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
-            self.ema = ModelEMA(self.model)
             if self.args.plots:
                 self.plot_training_labels()
 
@@ -403,10 +409,15 @@
                 if ni <= nw:
                     xi = [0, nw]  # x interp
                     self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
-                    for j, x in enumerate(self.optimizer.param_groups):
+                    for x in self.optimizer.param_groups:
                         # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                         x["lr"] = np.interp(
-                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
+                            ni,
+                            xi,
+                            [
+                                self.args.warmup_bias_lr if x.get("param_group") == "bias" else 0.0,
+                                x["initial_lr"] * self.lf(epoch),
+                            ],
                         )
                         if "momentum" in x:
                             x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
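Note: tagging groups with a `param_group` name replaces the old positional assumption that group 0 is the bias group (the MuSGD regex split further down reorders groups, so index-based detection would break). The warmup itself is plain linear interpolation over the first `nw` iterations; a self-contained sketch with illustrative values:

```python
import numpy as np

nw, lr0, warmup_bias_lr = 100, 0.01, 0.1  # illustrative warmup iters, base lr, bias warmup lr

for ni in (0, 50, 100):  # iteration index
    bias_lr = np.interp(ni, [0, nw], [warmup_bias_lr, lr0])  # bias group: falls 0.1 -> 0.01
    other_lr = np.interp(ni, [0, nw], [0.0, lr0])            # all other groups: rise 0.0 -> 0.01
    print(f"ni={ni:3d}  bias={bias_lr:.4f}  other={other_lr:.4f}")
```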
@@ -460,17 +471,20 @@
 
                 self.run_callbacks("on_train_batch_end")
 
+            if hasattr(unwrap_model(self.model).criterion, "update"):
+                unwrap_model(self.model).criterion.update()
+
             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
 
             self.run_callbacks("on_train_epoch_end")
             if RANK in {-1, 0}:
-                final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
 
-                # Validation
-                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
-                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
-                    self.metrics, self.fitness = self.validate()
+                # Validation
+                final_epoch = epoch + 1 >= self.epochs
+                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
+                    self.metrics, self.fitness = self.validate()
 
                 # NaN recovery
                 if self._handle_nan_recovery(epoch):
@@ -510,11 +524,11 @@
                 break  # must break all DDP ranks
             epoch += 1
 
+        seconds = time.time() - self.train_time_start
+        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+        # Do final val with best.pt
+        self.final_eval()
         if RANK in {-1, 0}:
-            # Do final val with best.pt
-            seconds = time.time() - self.train_time_start
-            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
-            self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
@@ -545,7 +559,7 @@
         total = torch.cuda.get_device_properties(self.device).total_memory
         return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
 
-    def _clear_memory(self, threshold: float = None):
+    def _clear_memory(self, threshold: float | None = None):
         """Clear accelerator memory by calling garbage collector and emptying cache."""
         if threshold:
             assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
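Note: `float | None` is PEP 604 union syntax, which at runtime requires Python 3.10+; it is safe here on older interpreters only because of the `from __future__ import annotations` added at the top of this file, which defers all annotations to strings (PEP 563). A quick sketch:

```python
from __future__ import annotations  # annotations are stored as strings, never evaluated at import

def clear_memory(threshold: float | None = None) -> None:
    """Accepted even on Python 3.8/3.9, since the union annotation is not evaluated."""

print(clear_memory.__annotations__)  # {'threshold': 'float | None', 'return': 'None'}
```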
@@ -618,25 +632,26 @@
         (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
 
     def get_dataset(self):
-        """
-        Get train and validation datasets from data dictionary.
+        """Get train and validation datasets from data dictionary.
 
         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
         """
         try:
-            if self.args.task == "classify":
-                data = check_cls_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] == "ndjson":
-                # Convert NDJSON to YOLO format
+            # Convert ul:// platform URIs and NDJSON files to local dataset format first
+            data_str = str(self.args.data)
+            if data_str.endswith(".ndjson") or (data_str.startswith("ul://") and "/datasets/" in data_str):
                 import asyncio
 
                 from ultralytics.data.converter import convert_ndjson_to_yolo
+                from ultralytics.utils.checks import check_file
 
-                yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
-                self.args.data = str(yaml_path)
-                data = check_det_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
+                self.args.data = str(asyncio.run(convert_ndjson_to_yolo(check_file(self.args.data))))
+
+            # Task-specific dataset checking
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
                 "pose",
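Note: the rewritten `get_dataset` hoists format conversion ahead of the task-specific checks, so both NDJSON exports and `ul://` platform dataset URIs are normalized to a local YAML before `check_cls_dataset`/`check_det_dataset` run. A sketch of just the routing predicate (hypothetical helper name):

```python
def needs_conversion(data) -> bool:
    """True for NDJSON files and ul:// platform dataset URIs, which get converted to YOLO YAML first."""
    s = str(data)
    return s.endswith(".ndjson") or (s.startswith("ul://") and "/datasets/" in s)

assert needs_conversion("export.ndjson")
assert needs_conversion("ul://my-team/datasets/coco8")
assert not needs_conversion("coco8.yaml")  # goes straight to the task-specific checks
```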
@@ -654,8 +669,7 @@
             return data
 
     def setup_model(self):
-        """
-        Load, create, or download model for any task.
+        """Load, create, or download model for any task.
 
         Returns:
             (dict): Optional checkpoint to resume training from.
@@ -688,14 +702,19 @@
         return batch
 
     def validate(self):
-        """
-        Run validation on val set using self.validator.
+        """Run validation on val set using self.validator.
 
         Returns:
             metrics (dict): Dictionary of validation metrics.
             fitness (float): Fitness score for the validation.
         """
+        if self.ema and self.world_size > 1:
+            # Sync EMA buffers from rank 0 to all ranks
+            for buffer in self.ema.ema.buffers():
+                dist.broadcast(buffer, src=0)
         metrics = self.validator(self)
+        if metrics is None:
+            return None, None
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
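Note: with `self.validator` and `self.ema` now built on every rank (see the `_setup_train` hunk above), validation runs under DDP, so EMA buffers such as BatchNorm running stats must agree across processes before metrics are computed; rank 0 is treated as the source of truth. A hedged sketch of that sync, guarded so it is a no-op outside an initialized process group:

```python
import torch
import torch.distributed as dist

def sync_buffers_from_rank0(model: torch.nn.Module) -> None:
    """Broadcast every buffer (e.g. BatchNorm running_mean/running_var) from rank 0 to all ranks."""
    if dist.is_available() and dist.is_initialized():
        for buf in model.buffers():
            dist.broadcast(buf, src=0)

sync_buffers_from_rank0(torch.nn.BatchNorm2d(8))  # no-op when run without torchrun
```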
@@ -706,11 +725,11 @@
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
-        """Return a NotImplementedError when the get_validator function is called."""
+        """Raise NotImplementedError (must be implemented by subclasses)."""
         raise NotImplementedError("get_validator function not implemented in trainer")
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """Return dataloader derived from torch.data.Dataloader."""
+        """Raise NotImplementedError (must return a `torch.utils.data.DataLoader` in subclasses)."""
         raise NotImplementedError("get_dataloader function not implemented in trainer")
 
     def build_dataset(self, img_path, mode="train", batch=None):
@@ -718,10 +737,9 @@
         raise NotImplementedError("build_dataset function not implemented in trainer")
 
     def label_loss_items(self, loss_items=None, prefix="train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
-        Note:
+        Notes:
             This is not needed for classification but necessary for segmentation & detection
         """
         return {"loss": loss_items} if loss_items is not None else ["loss"]
@@ -753,9 +771,9 @@
         n = len(metrics) + 2  # number of cols
         t = time.time() - self.train_time_start
         self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
-        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
+        s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
         with open(self.csv, "a", encoding="utf-8") as f:
-            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
+            f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")
 
     def plot_metrics(self):
         """Plot metrics from a CSV file."""
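Note: the rewrite drops the `tuple([...] + keys)` list concatenation in favor of starred tuples; `"%s," * n` simply repeats the format specifier n times. A tiny demonstration:

```python
keys, vals = ["box_loss", "mAP50"], [0.412, 0.873]
n = len(keys) + 2  # epoch and time columns plus the metric columns

header = ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
row = ("%.6g," * n % (1, 12.3, *vals)).rstrip(",") + "\n"
print(header + row)  # epoch,time,box_loss,mAP50 / 1,12.3,0.412,0.873
```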
@@ -768,20 +786,20 @@
 
     def final_eval(self):
         """Perform final evaluation and validation for object detection YOLO model."""
-        ckpt = {}
-        for f in self.last, self.best:
-            if f.exists():
-                if f is self.last:
-                    ckpt = strip_optimizer(f)
-                elif f is self.best:
-                    k = "train_results"  # update best.pt train_metrics from last.pt
-                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
-                    LOGGER.info(f"\nValidating {f}...")
-                    self.validator.args.plots = self.args.plots
-                    self.validator.args.compile = False  # disable final val compile as too slow
-                    self.metrics = self.validator(model=f)
-                    self.metrics.pop("fitness", None)
-                    self.run_callbacks("on_fit_epoch_end")
+        model = self.best if self.best.exists() else None
+        with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
+            if RANK in {-1, 0}:
+                ckpt = strip_optimizer(self.last) if self.last.exists() else {}
+                if model:
+                    # update best.pt train_metrics from last.pt
+                    strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
+        if model:
+            LOGGER.info(f"\nValidating {model}...")
+            self.validator.args.plots = self.args.plots
+            self.validator.args.compile = False  # disable final val compile as too slow
+            self.metrics = self.validator(model=model)
+            self.metrics.pop("fitness", None)
+            self.run_callbacks("on_fit_epoch_end")
 
     def check_resume(self, overrides):
         """Check if resume checkpoint exists and update arguments accordingly."""
@@ -804,10 +822,29 @@
                     "batch",
                     "device",
                     "close_mosaic",
+                    "augmentations",
+                    "save_period",
+                    "workers",
+                    "cache",
+                    "patience",
+                    "time",
+                    "freeze",
+                    "val",
+                    "plots",
                 ):  # allow arg updates to reduce memory or update device on resume
                     if k in overrides:
                         setattr(self.args, k, overrides[k])
 
+                # Handle augmentations parameter for resume: check if user provided custom augmentations
+                if ckpt_args.get("augmentations") is not None:
+                    # Augmentations were saved in checkpoint as reprs but can't be restored automatically
+                    LOGGER.warning(
+                        "Custom Albumentations transforms were used in the original training run but are not "
+                        "being restored. To preserve custom augmentations when resuming, you need to pass the "
+                        "'augmentations' parameter again to get expected results. Example: \n"
+                        f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
+                    )
+
             except Exception as e:
                 raise FileNotFoundError(
                     "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
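Note: in user-facing terms, the longer allow-list means these arguments can now be changed when resuming instead of being silently restored from the checkpoint. A hedged usage sketch (paths illustrative):

```python
from ultralytics import YOLO

# Resume an interrupted run with a smaller memory/CPU footprint. Per the
# allow-list above, batch, workers, cache, patience, val, plots, etc. may be
# overridden on resume; most other args are restored from the checkpoint.
model = YOLO("runs/detect/train/weights/last.pt")
model.train(resume=True, batch=8, workers=2, plots=False)
```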
@@ -887,23 +924,21 @@
             self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
 
     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
-        """
-        Construct an optimizer for the given model.
+        """Construct an optimizer for the given model.
 
         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
-            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
-                based on the number of iterations.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
+                number of iterations.
             lr (float, optional): The learning rate for the optimizer.
             momentum (float, optional): The momentum factor for the optimizer.
             decay (float, optional): The weight decay for the optimizer.
-            iterations (float, optional): The number of iterations, which determines the optimizer if
-                name is 'auto'.
+            iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
 
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
         """
-        g = [], [], []  # optimizer parameter groups
+        g = [{}, {}, {}, {}]  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
         if name == "auto":
             LOGGER.info(
@@ -913,38 +948,62 @@
             )
             nc = self.data.get("nc", 10)  # number of classes
             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
-            name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
+            name, lr, momentum = ("MuSGD", 0.01 if iterations > 10000 else lr_fit, 0.9)
             self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
 
-        for module_name, module in model.named_modules():
+        use_muon = name == "MuSGD"
+        for module_name, module in unwrap_model(model).named_modules():
             for param_name, param in module.named_parameters(recurse=False):
                 fullname = f"{module_name}.{param_name}" if module_name else param_name
-                if "bias" in fullname:  # bias (no decay)
-                    g[2].append(param)
+                if param.ndim >= 2 and use_muon:
+                    g[3][fullname] = param  # muon params
+                elif "bias" in fullname:  # bias (no decay)
+                    g[2][fullname] = param
                 elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
                     # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
-                    g[1].append(param)
+                    g[1][fullname] = param
                 else:  # weight (with decay)
-                    g[0].append(param)
+                    g[0][fullname] = param
+        if not use_muon:
+            g = [x.values() for x in g[:3]]  # convert to list of params
 
-        optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
+        optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "MuSGD", "auto"}
         name = {x.lower(): x for x in optimizers}.get(name.lower())
         if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
-            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+            optim_args = dict(lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
         elif name == "RMSProp":
-            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
-        elif name == "SGD":
-            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+            optim_args = dict(lr=lr, momentum=momentum)
+        elif name == "SGD" or name == "MuSGD":
+            optim_args = dict(lr=lr, momentum=momentum, nesterov=True)
         else:
             raise NotImplementedError(
                 f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
                 "Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
             )
 
-        optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
-        optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
+        num_params = [len(g[0]), len(g[1]), len(g[2])]  # number of param groups
+        g[2] = {"params": g[2], **optim_args, "param_group": "bias"}
+        g[0] = {"params": g[0], **optim_args, "weight_decay": decay, "param_group": "weight"}
+        g[1] = {"params": g[1], **optim_args, "weight_decay": 0.0, "param_group": "bn"}
+        muon, sgd = (0.1, 1.0) if iterations > 10000 else (0.5, 0.5)  # scale factor for MuSGD
+        if use_muon:
+            num_params[0] = len(g[3])  # update number of params
+            g[3] = {"params": g[3], **optim_args, "weight_decay": decay, "use_muon": True, "param_group": "muon"}
+            import re
+
+            # higher lr for certain parameters in MuSGD when finetuning
+            pattern = re.compile(r"(?=.*23)(?=.*cv3)|proto\.semseg|flow_model")
+            g_ = []  # new param groups
+            for x in g:
+                p = x.pop("params")
+                p1 = [v for k, v in p.items() if pattern.search(k)]
+                p2 = [v for k, v in p.items() if not pattern.search(k)]
+                g_.extend([{"params": p1, **x, "lr": lr * 3}, {"params": p2, **x}])
+            g = g_
+        optimizer = getattr(optim, name, partial(MuSGD, muon=muon, sgd=sgd))(params=g)
+
         LOGGER.info(
             f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
-            f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
+            f"{num_params[1]} weight(decay=0.0), {num_params[0]} weight(decay={decay}), {num_params[2]} bias(decay=0.0)"
         )
         return optimizer
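Note: the new auto path routes parameters by shape: with MuSGD, every parameter of `ndim >= 2` (conv/linear weight matrices) joins the Muon group, while biases and norm scales keep SGD-style groups; keying the groups by parameter name is what lets the finetuning regex above pick out specific parameters for a 3x learning rate. A toy sketch of just the routing rule, independent of the MuSGD class itself:

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Flatten(), nn.Linear(8, 4))
muon, bias, no_decay = {}, {}, {}
for name, param in model.named_parameters():
    if param.ndim >= 2:      # conv/linear weight matrices -> Muon group
        muon[name] = param
    elif "bias" in name:     # biases -> plain SGD group, no weight decay
        bias[name] = param
    else:                    # BatchNorm scales etc. -> plain SGD group, no weight decay
        no_decay[name] = param

print(sorted(muon))      # ['0.weight', '3.weight']
print(sorted(bias))      # ['0.bias', '1.bias', '3.bias']
print(sorted(no_decay))  # ['1.weight']
```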