dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243)
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,8 @@ Usage:
     $ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """
 
+from __future__ import annotations
+
 import gc
 import math
 import os
@@ -42,6 +44,7 @@ from ultralytics.utils.autobatch import check_train_batch_size
 from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
+from ultralytics.utils.plotting import plot_results
 from ultralytics.utils.torch_utils import (
     TORCH_2_4,
     EarlyStopping,
@@ -60,8 +63,7 @@ from ultralytics.utils.torch_utils import (
 
 
 class BaseTrainer:
-    """
-    A base class for creating trainers.
+    """A base class for creating trainers.
 
     This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
     and various training utilities. It supports both single-GPU and multi-GPU distributed training.
@@ -111,17 +113,17 @@ class BaseTrainer:
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the BaseTrainer class.
+        """Initialize the BaseTrainer class.
 
         Args:
             cfg (str, optional): Path to a configuration file.
             overrides (dict, optional): Configuration overrides.
             _callbacks (list, optional): List of callback functions.
         """
+        self.hub_session = overrides.pop("session", None)  # HUB
         self.args = get_cfg(cfg, overrides)
         self.check_resume(overrides)
-        self.device = select_device(self.args.device, self.args.batch)
+        self.device = select_device(self.args.device)
         # Update "-1" devices so post-training val does not repeat search
         self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
         self.validator = None
@@ -136,7 +138,12 @@ class BaseTrainer:
         if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
-            YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+            # Save run args, serializing augmentations as reprs for resume compatibility
+            args_dict = vars(self.args).copy()
+            if args_dict.get("augmentations") is not None:
+                # Serialize Albumentations transforms as their repr strings for checkpoint compatibility
+                args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
+            YAML.save(self.save_dir / "args.yaml", args_dict)  # save run args
         self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
         self.save_period = self.args.save_period
 
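Note on the hunk above: args.yaml must stay YAML-serializable, so any custom Albumentations transforms are stored as their repr() strings rather than as live objects. A minimal sketch of what this produces (hypothetical transform list; the exact repr text depends on the installed albumentations version):

    import albumentations as A  # optional dependency Ultralytics uses for custom augmentations

    augmentations = [A.Blur(p=0.01), A.CLAHE(p=0.01)]
    serialized = [repr(t) for t in augmentations]
    # e.g. ['Blur(p=0.01, blur_limit=(3, 7))', 'CLAHE(p=0.01, ...)'] -- plain strings that
    # YAML.save() can write, but that cannot be rebuilt into transform objects on resume
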
@@ -168,14 +175,29 @@ class BaseTrainer:
         self.tloss = None
         self.loss_names = ["Loss"]
         self.csv = self.save_dir / "results.csv"
+        if self.csv.exists() and not self.args.resume:
+            self.csv.unlink()
         self.plot_idx = [0, 1, 2]
-
-        # HUB
-        self.hub_session = None
+        self.nan_recovery_attempts = 0
 
         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
-        if RANK in {-1, 0}:
+
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device=None or device=''
+            world_size = 0
+
+        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
+        self.world_size = world_size
+        # Run subprocess if DDP training, else train normally
+        if RANK in {-1, 0} and not self.ddp:
             callbacks.add_integration_callbacks(self)
         # Start console logging immediately at trainer initialization
         self.run_callbacks("on_pretrain_routine_start")
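For reference, the device parsing moved into __init__ reduces to this mapping (standalone sketch; the torch.cuda.is_available() fallback for device=None/'' is simplified away):

    def world_size_from_device(device) -> int:
        if isinstance(device, str) and len(device):  # "0" or "0,1,2,3"
            return len(device.split(","))
        if isinstance(device, (tuple, list)):  # [0, 1, 2, 3]
            return len(device)
        if device in {"cpu", "mps"}:  # no CUDA process group
            return 0
        return 1  # device=None/'' with CUDA available defaults to one GPU

    assert world_size_from_device("0,1,2,3") == 4
    assert world_size_from_device("cpu") == 0
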
@@ -195,31 +217,20 @@ class BaseTrainer:
 
     def train(self):
         """Allow device='', device=None on Multi-GPU systems to default to device=0."""
-        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
-            world_size = len(self.args.device.split(","))
-        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
-            world_size = len(self.args.device)
-        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
-            world_size = 0
-        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
-            world_size = 1  # default to device 0
-        else:  # i.e. device=None or device=''
-            world_size = 0
-
         # Run subprocess if DDP training, else train normally
-        if world_size > 1 and "LOCAL_RANK" not in os.environ:
+        if self.ddp:
             # Argument checks
             if self.args.rect:
                 LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                 self.args.rect = False
             if self.args.batch < 1.0:
-                LOGGER.warning(
-                    "'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting default 'batch=16'"
+                raise ValueError(
+                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
+                    f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
                 )
-                self.args.batch = 16
 
             # Command
-            cmd, file = generate_ddp_command(world_size, self)
+            cmd, file = generate_ddp_command(self)
             try:
                 LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
                 subprocess.run(cmd, check=True)
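Behavior change worth noting: fractional batch sizes (AutoBatch) under Multi-GPU now fail fast instead of being silently reset to batch=16. Caller-side effect, sketched with the public YOLO API:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    model.train(data="coco8.yaml", device=[0, 1], batch=16)  # OK: explicit batch, a multiple of GPU count
    # model.train(data="coco8.yaml", device=[0, 1], batch=0.6)  # now raises ValueError (was: warning + batch=16)
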
@@ -229,7 +240,7 @@ class BaseTrainer:
                 ddp_cleanup(self, str(file))
 
         else:
-            self._do_train(world_size)
+            self._do_train()
 
     def _setup_scheduler(self):
         """Initialize training learning rate scheduler."""
@@ -239,32 +250,26 @@ class BaseTrainer:
         self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
 
-    def _setup_ddp(self, world_size):
+    def _setup_ddp(self):
         """Initialize and set the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
         self.device = torch.device("cuda", RANK)
-        # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
             backend="nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
-            world_size=world_size,
+            world_size=self.world_size,
         )
 
-    def _setup_train(self, world_size):
+    def _setup_train(self):
         """Build dataloaders and optimizer on correct rank process."""
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()
 
-        # Initialize loss criterion before compilation for torch.compile compatibility
-        if hasattr(self.model, "init_criterion"):
-            self.model.criterion = self.model.init_criterion()
-
         # Compile model
-        if self.args.compile:
-            self.model = attempt_compile(self.model, device=self.device)
+        self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
 
         # Freeze layers
         freeze_list = (
@@ -295,13 +300,13 @@ class BaseTrainer:
         callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
         self.amp = torch.tensor(check_amp(self.model), device=self.device)
         callbacks.default_callbacks = callbacks_backup  # restore callbacks
-        if RANK > -1 and world_size > 1:  # DDP
+        if RANK > -1 and self.world_size > 1:  # DDP
             dist.broadcast(self.amp.int(), src=0)  # broadcast from rank 0 to all other ranks; gloo errors with boolean
         self.amp = bool(self.amp)  # as boolean
         self.scaler = (
             torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
         )
-        if world_size > 1:
+        if self.world_size > 1:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
 
         # Check imgsz
@@ -314,22 +319,22 @@ class BaseTrainer:
             self.args.batch = self.batch_size = self.auto_batch()
 
         # Dataloaders
-        batch_size = self.batch_size // max(world_size, 1)
+        batch_size = self.batch_size // max(self.world_size, 1)
         self.train_loader = self.get_dataloader(
             self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
         )
+        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+        self.test_loader = self.get_dataloader(
+            self.data.get("val") or self.data.get("test"),
+            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
+            rank=LOCAL_RANK,
+            mode="val",
+        )
+        self.validator = self.get_validator()
+        self.ema = ModelEMA(self.model)
         if RANK in {-1, 0}:
-            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
-            self.test_loader = self.get_dataloader(
-                self.data.get("val") or self.data.get("test"),
-                batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
-                rank=-1,
-                mode="val",
-            )
-            self.validator = self.get_validator()
             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
-            self.ema = ModelEMA(self.model)
             if self.args.plots:
                 self.plot_training_labels()
@@ -352,11 +357,11 @@ class BaseTrainer:
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
         self.run_callbacks("on_pretrain_routine_end")
 
-    def _do_train(self, world_size=1):
+    def _do_train(self):
         """Train the model with the specified world size."""
-        if world_size > 1:
-            self._setup_ddp(world_size)
-        self._setup_train(world_size)
+        if self.world_size > 1:
+            self._setup_ddp()
+        self._setup_train()
 
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
@@ -367,7 +372,7 @@ class BaseTrainer:
         self.run_callbacks("on_train_start")
         LOGGER.info(
             f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
-            f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
+            f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
             f"Logging results to {colorstr('bold', self.save_dir)}\n"
            f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
         )
@@ -414,19 +419,19 @@ class BaseTrainer:
                 # Forward
                 with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    metadata = {k: batch.pop(k, None) for k in ["im_file", "ori_shape", "resized_shape"]}
-                    loss, self.loss_items = self.model(batch)
+                    if self.args.compile:
+                        # Decouple inference and loss calculations for improved compile performance
+                        preds = self.model(batch["img"])
+                        loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
+                    else:
+                        loss, self.loss_items = self.model(batch)
                 self.loss = loss.sum()
                 if RANK != -1:
-                    self.loss *= world_size
-                self.tloss = (
-                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
-                )
+                    self.loss *= self.world_size
+                self.tloss = self.loss_items if self.tloss is None else (self.tloss * i + self.loss_items) / (i + 1)
 
                 # Backward
                 self.scaler.scale(self.loss).backward()
-
-                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                 if ni - last_opt_step >= self.accumulate:
                     self.optimizer_step()
                     last_opt_step = ni
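The compile branch above keeps only the tensor-in/tensor-out forward pass inside the compiled module and runs the loss eagerly on the unwrapped model, avoiding graph breaks from the loss's Python-side target handling. Shape of the split (sketch; assumes the task model exposes a loss(batch, preds) entry point, as the diff itself does via unwrap_model):

    compiled = torch.compile(model)  # compile the pure forward only
    preds = compiled(batch["img"])  # compiled inference on the image tensor
    loss, loss_items = model.loss(batch, preds)  # eager loss on the underlying module
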
@@ -456,21 +461,28 @@ class BaseTrainer:
                     )
                     self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
-                        batch = {**batch, **metadata}
                         self.plot_training_samples(batch, ni)
 
                 self.run_callbacks("on_train_batch_end")
 
             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+
             self.run_callbacks("on_train_epoch_end")
             if RANK in {-1, 0}:
-                final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
 
-                # Validation
-                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
-                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
-                    self.metrics, self.fitness = self.validate()
+            # Validation
+            final_epoch = epoch + 1 >= self.epochs
+            if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                self._clear_memory(threshold=0.5)  # prevent VRAM spike
+                self.metrics, self.fitness = self.validate()
+
+            # NaN recovery
+            if self._handle_nan_recovery(epoch):
+                continue
+
+            self.nan_recovery_attempts = 0
+            if RANK in {-1, 0}:
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                 self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
                 if self.args.time:
@@ -503,11 +515,11 @@ class BaseTrainer:
                 break  # must break all DDP ranks
             epoch += 1
 
+        seconds = time.time() - self.train_time_start
+        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+        # Do final val with best.pt
+        self.final_eval()
         if RANK in {-1, 0}:
-            # Do final val with best.pt
-            seconds = time.time() - self.train_time_start
-            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
-            self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
@@ -538,7 +550,7 @@ class BaseTrainer:
         total = torch.cuda.get_device_properties(self.device).total_memory
         return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
 
-    def _clear_memory(self, threshold: float = None):
+    def _clear_memory(self, threshold: float | None = None):
         """Clear accelerator memory by calling garbage collector and emptying cache."""
         if threshold:
             assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
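The `float | None` annotation here is what the `from __future__ import annotations` added in the first hunk enables: under PEP 563 annotations are stored as strings rather than evaluated, so PEP 604 unions parse even on Python versions where evaluating `float | None` at definition time would raise TypeError. Self-contained illustration:

    from __future__ import annotations  # annotations become strings, not evaluated objects

    def _clear_memory(threshold: float | None = None) -> None:  # valid even on Python 3.9
        ...
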
@@ -556,7 +568,10 @@ class BaseTrainer:
         """Read results.csv into a dictionary using polars."""
         import polars as pl  # scope for faster 'import ultralytics'
 
-        return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
+        try:
+            return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
+        except Exception:
+            return {}
 
     def _model_train(self):
         """Set model in training mode."""
@@ -580,6 +595,7 @@ class BaseTrainer:
             "ema": deepcopy(unwrap_model(self.ema.ema)).half(),
             "updates": self.ema.updates,
             "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+            "scaler": self.scaler.state_dict(),
             "train_args": vars(self.args),  # save as dict
             "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
             "train_results": self.read_results_csv(),
@@ -599,6 +615,7 @@ class BaseTrainer:
         serialized_ckpt = buffer.getvalue()  # get the serialized content to save
 
         # Save checkpoints
+        self.wdir.mkdir(parents=True, exist_ok=True)  # ensure weights directory exists
         self.last.write_bytes(serialized_ckpt)  # save last.pt
         if self.best_fitness == self.fitness:
             self.best.write_bytes(serialized_ckpt)  # save best.pt
@@ -606,8 +623,7 @@ class BaseTrainer:
             (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
 
     def get_dataset(self):
-        """
-        Get train and validation datasets from data dictionary.
+        """Get train and validation datasets from data dictionary.
 
         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
@@ -615,7 +631,7 @@ class BaseTrainer:
         try:
             if self.args.task == "classify":
                 data = check_cls_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] == "ndjson":
+            elif str(self.args.data).rsplit(".", 1)[-1] == "ndjson":
                 # Convert NDJSON to YOLO format
                 import asyncio
 
@@ -624,7 +640,7 @@ class BaseTrainer:
                 yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
                 self.args.data = str(yaml_path)
                 data = check_det_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
+            elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
                 "pose",
@@ -642,8 +658,7 @@ class BaseTrainer:
         return data
 
     def setup_model(self):
-        """
-        Load, create, or download model for any task.
+        """Load, create, or download model for any task.
 
         Returns:
             (dict): Optional checkpoint to resume training from.
@@ -664,7 +679,7 @@ class BaseTrainer:
     def optimizer_step(self):
         """Perform a single step of the training optimizer with gradient clipping and EMA update."""
         self.scaler.unscale_(self.optimizer)  # unscale gradients
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self.optimizer.zero_grad()
@@ -676,14 +691,19 @@ class BaseTrainer:
         return batch
 
     def validate(self):
-        """
-        Run validation on val set using self.validator.
+        """Run validation on val set using self.validator.
 
         Returns:
             metrics (dict): Dictionary of validation metrics.
             fitness (float): Fitness score for the validation.
         """
+        if self.ema and self.world_size > 1:
+            # Sync EMA buffers from rank 0 to all ranks
+            for buffer in self.ema.ema.buffers():
+                dist.broadcast(buffer, src=0)
         metrics = self.validator(self)
+        if metrics is None:
+            return None, None
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
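The new broadcast keeps module buffers (BatchNorm running_mean/running_var, num_batches_tracked) identical across ranks before validation, since DDP gradient reduction does not synchronize buffers on the EMA copy. The pattern in isolation (sketch; ema_model is a hypothetical name for the EMA network):

    import torch.distributed as dist

    for buf in ema_model.buffers():
        dist.broadcast(buf, src=0)  # every rank overwrites its buffer with rank 0's values
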
@@ -694,11 +714,11 @@ class BaseTrainer:
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
-        """Return a NotImplementedError when the get_validator function is called."""
+        """Raise NotImplementedError (must be implemented by subclasses)."""
         raise NotImplementedError("get_validator function not implemented in trainer")
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """Return dataloader derived from torch.data.Dataloader."""
+        """Raise NotImplementedError (must return a `torch.utils.data.DataLoader` in subclasses)."""
         raise NotImplementedError("get_dataloader function not implemented in trainer")
 
     def build_dataset(self, img_path, mode="train", batch=None):
@@ -706,10 +726,9 @@ class BaseTrainer:
         raise NotImplementedError("build_dataset function not implemented in trainer")
 
     def label_loss_items(self, loss_items=None, prefix="train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
-        Note:
+        Notes:
             This is not needed for classification but necessary for segmentation & detection
         """
         return {"loss": loss_items} if loss_items is not None else ["loss"]
@@ -739,14 +758,15 @@ class BaseTrainer:
         """Save training metrics to a CSV file."""
         keys, vals = list(metrics.keys()), list(metrics.values())
         n = len(metrics) + 2  # number of cols
-        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
         t = time.time() - self.train_time_start
+        self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
+        s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
         with open(self.csv, "a", encoding="utf-8") as f:
-            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
+            f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")
 
     def plot_metrics(self):
-        """Plot and display metrics visually."""
-        pass
+        """Plot metrics from a CSV file."""
+        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png
 
     def on_plot(self, name, data=None):
         """Register plots (e.g. to be consumed in callbacks)."""
@@ -755,20 +775,20 @@ class BaseTrainer:
 
     def final_eval(self):
         """Perform final evaluation and validation for object detection YOLO model."""
-        ckpt = {}
-        for f in self.last, self.best:
-            if f.exists():
-                if f is self.last:
-                    ckpt = strip_optimizer(f)
-                elif f is self.best:
-                    k = "train_results"  # update best.pt train_metrics from last.pt
-                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
-                LOGGER.info(f"\nValidating {f}...")
-                self.validator.args.plots = self.args.plots
-                self.validator.args.compile = False  # disable final val compile as too slow
-                self.metrics = self.validator(model=f)
-                self.metrics.pop("fitness", None)
-                self.run_callbacks("on_fit_epoch_end")
+        model = self.best if self.best.exists() else None
+        with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
+            if RANK in {-1, 0}:
+                ckpt = strip_optimizer(self.last) if self.last.exists() else {}
+                if model:
+                    # update best.pt train_metrics from last.pt
+                    strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
+        if model:
+            LOGGER.info(f"\nValidating {model}...")
+            self.validator.args.plots = self.args.plots
+            self.validator.args.compile = False  # disable final val compile as too slow
+            self.metrics = self.validator(model=model)
+            self.metrics.pop("fitness", None)
+            self.run_callbacks("on_fit_epoch_end")
 
     def check_resume(self, overrides):
         """Check if resume checkpoint exists and update arguments accordingly."""
@@ -791,10 +811,29 @@ class BaseTrainer:
                     "batch",
                     "device",
                     "close_mosaic",
+                    "augmentations",
+                    "save_period",
+                    "workers",
+                    "cache",
+                    "patience",
+                    "time",
+                    "freeze",
+                    "val",
+                    "plots",
                 ):  # allow arg updates to reduce memory or update device on resume
                     if k in overrides:
                         setattr(self.args, k, overrides[k])
 
+                # Handle augmentations parameter for resume: check if user provided custom augmentations
+                if ckpt_args.get("augmentations") is not None:
+                    # Augmentations were saved in checkpoint as reprs but can't be restored automatically
+                    LOGGER.warning(
+                        "Custom Albumentations transforms were used in the original training run but are not "
+                        "being restored. To preserve custom augmentations when resuming, you need to pass the "
+                        "'augmentations' parameter again to get expected results. Example: \n"
+                        f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
+                    )
+
             except Exception as e:
                 raise FileNotFoundError(
                     "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
@@ -802,18 +841,54 @@ class BaseTrainer:
                 ) from e
         self.resume = resume
 
+    def _load_checkpoint_state(self, ckpt):
+        """Load optimizer, scaler, EMA, and best_fitness from checkpoint."""
+        if ckpt.get("optimizer") is not None:
+            self.optimizer.load_state_dict(ckpt["optimizer"])
+        if ckpt.get("scaler") is not None:
+            self.scaler.load_state_dict(ckpt["scaler"])
+        if self.ema and ckpt.get("ema"):
+            self.ema = ModelEMA(self.model)  # validation with EMA creates inference tensors that can't be updated
+            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
+            self.ema.updates = ckpt["updates"]
+        self.best_fitness = ckpt.get("best_fitness", 0.0)
+
+    def _handle_nan_recovery(self, epoch):
+        """Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint."""
+        loss_nan = self.loss is not None and not self.loss.isfinite()
+        fitness_nan = self.fitness is not None and not np.isfinite(self.fitness)
+        fitness_collapse = self.best_fitness and self.best_fitness > 0 and self.fitness == 0
+        corrupted = RANK in {-1, 0} and loss_nan and (fitness_nan or fitness_collapse)
+        reason = "Loss NaN/Inf" if loss_nan else "Fitness NaN/Inf" if fitness_nan else "Fitness collapse"
+        if RANK != -1:  # DDP: broadcast to all ranks
+            broadcast_list = [corrupted if RANK == 0 else None]
+            dist.broadcast_object_list(broadcast_list, 0)
+            corrupted = broadcast_list[0]
+        if not corrupted:
+            return False
+        if epoch == self.start_epoch or not self.last.exists():
+            LOGGER.warning(f"{reason} detected but can not recover from last.pt...")
+            return False  # Cannot recover on first epoch, let training continue
+        self.nan_recovery_attempts += 1
+        if self.nan_recovery_attempts > 3:
+            raise RuntimeError(f"Training failed: NaN persisted for {self.nan_recovery_attempts} epochs")
+        LOGGER.warning(f"{reason} detected (attempt {self.nan_recovery_attempts}/3), recovering from last.pt...")
+        self._model_train()  # set model to train mode before loading checkpoint to avoid inference tensor errors
+        _, ckpt = load_checkpoint(self.last)
+        ema_state = ckpt["ema"].float().state_dict()
+        if not all(torch.isfinite(v).all() for v in ema_state.values() if isinstance(v, torch.Tensor)):
+            raise RuntimeError(f"Checkpoint {self.last} is corrupted with NaN/Inf weights")
+        unwrap_model(self.model).load_state_dict(ema_state)  # Load EMA weights into model
+        self._load_checkpoint_state(ckpt)  # Load optimizer/scaler/EMA/best_fitness
+        del ckpt, ema_state
+        self.scheduler.last_epoch = epoch - 1
+        return True
+
     def resume_training(self, ckpt):
         """Resume YOLO training from given epoch and best fitness."""
         if ckpt is None or not self.resume:
             return
-        best_fitness = 0.0
         start_epoch = ckpt.get("epoch", -1) + 1
-        if ckpt.get("optimizer", None) is not None:
-            self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
-            best_fitness = ckpt["best_fitness"]
-        if self.ema and ckpt.get("ema"):
-            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
-            self.ema.updates = ckpt["updates"]
         assert start_epoch > 0, (
             f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
             f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
@@ -824,7 +899,7 @@ class BaseTrainer:
                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
            )
         self.epochs += ckpt["epoch"]  # finetune additional epochs
-        self.best_fitness = best_fitness
+        self._load_checkpoint_state(ckpt)
         self.start_epoch = start_epoch
         if start_epoch > (self.epochs - self.args.close_mosaic):
             self._close_dataloader_mosaic()
@@ -838,18 +913,16 @@ class BaseTrainer:
             self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
 
     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
-        """
-        Construct an optimizer for the given model.
+        """Construct an optimizer for the given model.
 
         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
-            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
-                based on the number of iterations.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
+                number of iterations.
             lr (float, optional): The learning rate for the optimizer.
             momentum (float, optional): The momentum factor for the optimizer.
             decay (float, optional): The weight decay for the optimizer.
-            iterations (float, optional): The number of iterations, which determines the optimizer if
-                name is 'auto'.
+            iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
 
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.