dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
@@ -6,6 +6,8 @@ Usage:
     $ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """

+from __future__ import annotations
+
 import gc
 import math
 import os
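The new from __future__ import annotations defers evaluation of type hints, which is what lets later hunks use PEP 604 unions such as float | None in signatures (see _clear_memory below) while the package still supports Python 3.8/3.9. A minimal sketch of the effect, with an illustrative function name:

    from __future__ import annotations  # annotations become strings, evaluated lazily

    def clear_memory(threshold: float | None = None) -> None:
        # Without the future import, `float | None` raises TypeError at def time on Python < 3.10.
        if threshold is not None:
            assert 0 <= threshold <= 1

    clear_memory(0.5)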
@@ -24,9 +26,10 @@ from torch import nn, optim
 from ultralytics import __version__
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
-from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
+from ultralytics.nn.tasks import load_checkpoint
 from ultralytics.utils import (
     DEFAULT_CFG,
+    GIT,
     LOCAL_RANK,
     LOGGER,
     RANK,
@@ -41,10 +44,12 @@ from ultralytics.utils.autobatch import check_train_batch_size
 from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
+from ultralytics.utils.plotting import plot_results
 from ultralytics.utils.torch_utils import (
     TORCH_2_4,
     EarlyStopping,
     ModelEMA,
+    attempt_compile,
     autocast,
     convert_optimizer_state_dict_to_fp16,
     init_seeds,
@@ -53,12 +58,15 @@ from ultralytics.utils.torch_utils import (
     strip_optimizer,
     torch_distributed_zero_first,
     unset_deterministic,
+    unwrap_model,
 )


 class BaseTrainer:
-    """
-    A base class for creating trainers.
+    """A base class for creating trainers.
+
+    This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
+    and various training utilities. It supports both single-GPU and multi-GPU distributed training.

     Attributes:
         args (SimpleNamespace): Configuration for the trainer.
@@ -89,21 +97,34 @@ class BaseTrainer:
         csv (Path): Path to results CSV file.
         metrics (dict): Dictionary of metrics.
         plots (dict): Dictionary of plots.
+
+    Methods:
+        train: Execute the training process.
+        validate: Run validation on the test set.
+        save_model: Save model training checkpoints.
+        get_dataset: Get train and validation datasets.
+        setup_model: Load, create, or download model.
+        build_optimizer: Construct an optimizer for the model.
+
+    Examples:
+        Initialize a trainer and start training
+        >>> trainer = BaseTrainer(cfg="config.yaml")
+        >>> trainer.train()
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the BaseTrainer class.
+        """Initialize the BaseTrainer class.

         Args:
-            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
-            overrides (dict, optional): Configuration overrides. Defaults to None.
-            _callbacks (list, optional): List of callback functions. Defaults to None.
+            cfg (str, optional): Path to a configuration file.
+            overrides (dict, optional): Configuration overrides.
+            _callbacks (list, optional): List of callback functions.
         """
+        self.hub_session = overrides.pop("session", None)  # HUB
         self.args = get_cfg(cfg, overrides)
         self.check_resume(overrides)
-        self.device = select_device(self.args.device, self.args.batch)
-        # update "-1" devices so post-training val does not repeat search
+        self.device = select_device(self.args.device)
+        # Update "-1" devices so post-training val does not repeat search
         self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
         self.validator = None
         self.metrics = None
@@ -149,15 +170,32 @@ class BaseTrainer:
         self.tloss = None
         self.loss_names = ["Loss"]
         self.csv = self.save_dir / "results.csv"
+        if self.csv.exists() and not self.args.resume:
+            self.csv.unlink()
         self.plot_idx = [0, 1, 2]
-
-        # HUB
-        self.hub_session = None
+        self.nan_recovery_attempts = 0

         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
-        if RANK in {-1, 0}:
+
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device=None or device=''
+            world_size = 0
+
+        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
+        self.world_size = world_size
+        # Run subprocess if DDP training, else train normally
+        if RANK in {-1, 0} and not self.ddp:
             callbacks.add_integration_callbacks(self)
+            # Start console logging immediately at trainer initialization
+            self.run_callbacks("on_pretrain_routine_start")

     def add_callback(self, event: str, callback):
         """Append the given callback to the event's callback list."""
@@ -174,31 +212,20 @@ class BaseTrainer:

     def train(self):
         """Allow device='', device=None on Multi-GPU systems to default to device=0."""
-        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
-            world_size = len(self.args.device.split(","))
-        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
-            world_size = len(self.args.device)
-        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
-            world_size = 0
-        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
-            world_size = 1  # default to device 0
-        else:  # i.e. device=None or device=''
-            world_size = 0
-
         # Run subprocess if DDP training, else train normally
-        if world_size > 1 and "LOCAL_RANK" not in os.environ:
+        if self.ddp:
             # Argument checks
             if self.args.rect:
                 LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                 self.args.rect = False
             if self.args.batch < 1.0:
-                LOGGER.warning(
-                    "'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting default 'batch=16'"
+                raise ValueError(
+                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
+                    f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
                 )
-                self.args.batch = 16

             # Command
-            cmd, file = generate_ddp_command(world_size, self)
+            cmd, file = generate_ddp_command(self)
             try:
                 LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
                 subprocess.run(cmd, check=True)
@@ -208,7 +235,7 @@ class BaseTrainer:
                 ddp_cleanup(self, str(file))

         else:
-            self._do_train(world_size)
+            self._do_train()

     def _setup_scheduler(self):
         """Initialize training learning rate scheduler."""
@@ -218,27 +245,27 @@ class BaseTrainer:
         self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)

-    def _setup_ddp(self, world_size):
+    def _setup_ddp(self):
         """Initialize and set the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
         self.device = torch.device("cuda", RANK)
-        # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
             backend="nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
-            world_size=world_size,
+            world_size=self.world_size,
         )

-    def _setup_train(self, world_size):
+    def _setup_train(self):
         """Build dataloaders and optimizer on correct rank process."""
-        # Model
-        self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()

+        # Compile model
+        self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
+
         # Freeze layers
         freeze_list = (
             self.args.freeze
@@ -268,13 +295,13 @@ class BaseTrainer:
             callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
             self.amp = torch.tensor(check_amp(self.model), device=self.device)
             callbacks.default_callbacks = callbacks_backup  # restore callbacks
-        if RANK > -1 and world_size > 1:  # DDP
+        if RANK > -1 and self.world_size > 1:  # DDP
             dist.broadcast(self.amp.int(), src=0)  # broadcast from rank 0 to all other ranks; gloo errors with boolean
         self.amp = bool(self.amp)  # as boolean
         self.scaler = (
             torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
         )
-        if world_size > 1:
+        if self.world_size > 1:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)

         # Check imgsz
@@ -287,22 +314,22 @@ class BaseTrainer:
             self.args.batch = self.batch_size = self.auto_batch()

         # Dataloaders
-        batch_size = self.batch_size // max(world_size, 1)
+        batch_size = self.batch_size // max(self.world_size, 1)
         self.train_loader = self.get_dataloader(
             self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
         )
+        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+        self.test_loader = self.get_dataloader(
+            self.data.get("val") or self.data.get("test"),
+            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
+            rank=LOCAL_RANK,
+            mode="val",
+        )
+        self.validator = self.get_validator()
+        self.ema = ModelEMA(self.model)
         if RANK in {-1, 0}:
-            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
-            self.test_loader = self.get_dataloader(
-                self.data.get("val") or self.data.get("test"),
-                batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
-                rank=-1,
-                mode="val",
-            )
-            self.validator = self.get_validator()
             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
-            self.ema = ModelEMA(self.model)
             if self.args.plots:
                 self.plot_training_labels()

@@ -325,11 +352,11 @@ class BaseTrainer:
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
         self.run_callbacks("on_pretrain_routine_end")

-    def _do_train(self, world_size=1):
+    def _do_train(self):
         """Train the model with the specified world size."""
-        if world_size > 1:
-            self._setup_ddp(world_size)
-        self._setup_train(world_size)
+        if self.world_size > 1:
+            self._setup_ddp()
+        self._setup_train()

         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
@@ -340,7 +367,7 @@ class BaseTrainer:
         self.run_callbacks("on_train_start")
         LOGGER.info(
             f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
-            f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
+            f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
             f"Logging results to {colorstr('bold', self.save_dir)}\n"
             f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
         )
@@ -387,18 +414,19 @@ class BaseTrainer:
                 # Forward
                 with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    loss, self.loss_items = self.model(batch)
+                    if self.args.compile:
+                        # Decouple inference and loss calculations for improved compile performance
+                        preds = self.model(batch["img"])
+                        loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
+                    else:
+                        loss, self.loss_items = self.model(batch)
                     self.loss = loss.sum()
                     if RANK != -1:
-                        self.loss *= world_size
-                    self.tloss = (
-                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
-                    )
+                        self.loss *= self.world_size
+                    self.tloss = self.loss_items if self.tloss is None else (self.tloss * i + self.loss_items) / (i + 1)

                 # Backward
                 self.scaler.scale(self.loss).backward()
-
-                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                 if ni - last_opt_step >= self.accumulate:
                     self.optimizer_step()
                     last_opt_step = ni
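When compile is enabled, the loop above runs the pure forward pass through the compiled module and keeps the loss computation eager on the unwrapped model, so the loss graph never enters the compiled artifact. A hedged, self-contained sketch of that split; TinyTask and its loss() method are stand-ins for a YOLO task model, not the ultralytics API:

    import torch
    import torch.nn as nn

    class TinyTask(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 2)

        def forward(self, x):  # pure inference path, safe to compile
            return self.net(x)

        def loss(self, batch, preds):  # stays eager, outside the compiled graph
            return nn.functional.mse_loss(preds, batch["target"]), preds.detach()

    model = TinyTask()
    compiled = torch.compile(model)  # compile only wraps forward()
    batch = {"img": torch.randn(8, 4), "target": torch.randn(8, 2)}
    preds = compiled(batch["img"])  # compiled forward
    loss, loss_items = model.loss(batch, preds)  # eager loss on the original module
    loss.sum().backward()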
@@ -433,14 +461,23 @@ class BaseTrainer:
                     self.run_callbacks("on_train_batch_end")

             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+
             self.run_callbacks("on_train_epoch_end")
             if RANK in {-1, 0}:
                 final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

-                # Validation
-                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
-                    self.metrics, self.fitness = self.validate()
+                # Validation
+                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
+                    self.metrics, self.fitness = self.validate()
+
+            # NaN recovery
+            if self._handle_nan_recovery(epoch):
+                continue
+
+            self.nan_recovery_attempts = 0
+            if RANK in {-1, 0}:
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                 self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
                 if self.args.time:
@@ -462,8 +499,7 @@ class BaseTrainer:
             self.scheduler.last_epoch = self.epoch  # do not move
             self.stop |= epoch >= self.epochs  # stop if exceeded epochs
             self.run_callbacks("on_fit_epoch_end")
-            if self._get_memory(fraction=True) > 0.5:
-                self._clear_memory()  # clear if memory utilization > 50%
+            self._clear_memory(0.5)  # clear if memory utilization > 50%

             # Early Stopping
             if RANK != -1:  # if DDP training
@@ -474,11 +510,11 @@ class BaseTrainer:
                 break  # must break all DDP ranks
             epoch += 1

+        seconds = time.time() - self.train_time_start
+        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+        # Do final val with best.pt
+        self.final_eval()
         if RANK in {-1, 0}:
-            # Do final val with best.pt
-            seconds = time.time() - self.train_time_start
-            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
-            self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
@@ -509,8 +545,12 @@ class BaseTrainer:
                 total = torch.cuda.get_device_properties(self.device).total_memory
         return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)

-    def _clear_memory(self):
+    def _clear_memory(self, threshold: float | None = None):
         """Clear accelerator memory by calling garbage collector and emptying cache."""
+        if threshold:
+            assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
+            if self._get_memory(fraction=True) <= threshold:
+                return
         gc.collect()
         if self.device.type == "mps":
             torch.mps.empty_cache()
@@ -520,10 +560,13 @@ class BaseTrainer:
             torch.cuda.empty_cache()

     def read_results_csv(self):
-        """Read results.csv into a dictionary using pandas."""
-        import pandas as pd  # scope for faster 'import ultralytics'
+        """Read results.csv into a dictionary using polars."""
+        import polars as pl  # scope for faster 'import ultralytics'

-        return pd.read_csv(self.csv).to_dict(orient="list")
+        try:
+            return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
+        except Exception:
+            return {}

     def _model_train(self):
         """Set model in training mode."""
@@ -544,14 +587,21 @@ class BaseTrainer:
                 "epoch": self.epoch,
                 "best_fitness": self.best_fitness,
                 "model": None,  # resume and final checkpoints derive from EMA
-                "ema": deepcopy(self.ema.ema).half(),
+                "ema": deepcopy(unwrap_model(self.ema.ema)).half(),
                 "updates": self.ema.updates,
                 "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+                "scaler": self.scaler.state_dict(),
                 "train_args": vars(self.args),  # save as dict
                 "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
                 "train_results": self.read_results_csv(),
                 "date": datetime.now().isoformat(),
                 "version": __version__,
+                "git": {
+                    "root": str(GIT.root),
+                    "branch": GIT.branch,
+                    "commit": GIT.commit,
+                    "origin": GIT.origin,
+                },
                 "license": "AGPL-3.0 (https://ultralytics.com/license)",
                 "docs": "https://docs.ultralytics.com",
             },
@@ -560,17 +610,15 @@ class BaseTrainer:
         serialized_ckpt = buffer.getvalue()  # get the serialized content to save

         # Save checkpoints
+        self.wdir.mkdir(parents=True, exist_ok=True)  # ensure weights directory exists
         self.last.write_bytes(serialized_ckpt)  # save last.pt
         if self.best_fitness == self.fitness:
             self.best.write_bytes(serialized_ckpt)  # save best.pt
         if (self.save_period > 0) and (self.epoch % self.save_period == 0):
             (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
-        # if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
-        #     (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt)  # save mosaic checkpoint

     def get_dataset(self):
-        """
-        Get train and validation datasets from data dictionary.
+        """Get train and validation datasets from data dictionary.

         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
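With the added "scaler" and "git" entries, a saved checkpoint carries AMP state and provenance that can be inspected directly. A hedged sketch (the path is illustrative, and weights_only=False is needed on newer PyTorch because the checkpoint stores full Python objects):

    import torch

    ckpt = torch.load("runs/detect/train/weights/last.pt", map_location="cpu", weights_only=False)
    print(ckpt["version"])  # ultralytics version that wrote the checkpoint
    print(ckpt["git"])  # {'root': ..., 'branch': ..., 'commit': ..., 'origin': ...}
    scaler_state = ckpt["scaler"]  # AMP GradScaler state, restored on resume and NaN recovery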
@@ -578,7 +626,16 @@ class BaseTrainer:
         try:
             if self.args.task == "classify":
                 data = check_cls_dataset(self.args.data)
-            elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
+            elif self.args.data.rsplit(".", 1)[-1] == "ndjson":
+                # Convert NDJSON to YOLO format
+                import asyncio
+
+                from ultralytics.data.converter import convert_ndjson_to_yolo
+
+                yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
+                self.args.data = str(yaml_path)
+                data = check_det_dataset(self.args.data)
+            elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
                 "pose",
@@ -596,8 +653,7 @@ class BaseTrainer:
         return data

     def setup_model(self):
-        """
-        Load, create, or download model for any task.
+        """Load, create, or download model for any task.

         Returns:
             (dict): Optional checkpoint to resume training from.
@@ -608,17 +664,17 @@ class BaseTrainer:
         cfg, weights = self.model, None
         ckpt = None
         if str(self.model).endswith(".pt"):
-            weights, ckpt = attempt_load_one_weight(self.model)
+            weights, ckpt = load_checkpoint(self.model)
             cfg = weights.yaml
         elif isinstance(self.args.pretrained, (str, Path)):
-            weights, _ = attempt_load_one_weight(self.args.pretrained)
+            weights, _ = load_checkpoint(self.args.pretrained)
         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
         return ckpt

     def optimizer_step(self):
         """Perform a single step of the training optimizer with gradient clipping and EMA update."""
         self.scaler.unscale_(self.optimizer)  # unscale gradients
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self.optimizer.zero_grad()
@@ -626,17 +682,23 @@ class BaseTrainer:
             self.ema.update(self.model)

     def preprocess_batch(self, batch):
-        """Allows custom preprocessing model inputs and ground truths depending on task type."""
+        """Allow custom preprocessing model inputs and ground truths depending on task type."""
         return batch

     def validate(self):
-        """
-        Run validation on test set using self.validator.
+        """Run validation on val set using self.validator.

         Returns:
-            (tuple): A tuple containing metrics dictionary and fitness score.
+            metrics (dict): Dictionary of validation metrics.
+            fitness (float): Fitness score for the validation.
         """
+        if self.ema and self.world_size > 1:
+            # Sync EMA buffers from rank 0 to all ranks
+            for buffer in self.ema.ema.buffers():
+                dist.broadcast(buffer, src=0)
         metrics = self.validator(self)
+        if metrics is None:
+            return None, None
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
@@ -647,11 +709,11 @@ class BaseTrainer:
         raise NotImplementedError("This task trainer doesn't support loading cfg files")

     def get_validator(self):
-        """Returns a NotImplementedError when the get_validator function is called."""
+        """Return a NotImplementedError when the get_validator function is called."""
         raise NotImplementedError("get_validator function not implemented in trainer")

     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """Returns dataloader derived from torch.data.Dataloader."""
+        """Return dataloader derived from torch.data.Dataloader."""
         raise NotImplementedError("get_dataloader function not implemented in trainer")

     def build_dataset(self, img_path, mode="train", batch=None):
@@ -659,10 +721,9 @@ class BaseTrainer:
         raise NotImplementedError("build_dataset function not implemented in trainer")

     def label_loss_items(self, loss_items=None, prefix="train"):
-        """
-        Returns a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.

-        Note:
+        Notes:
             This is not needed for classification but necessary for segmentation & detection
         """
         return {"loss": loss_items} if loss_items is not None else ["loss"]
@@ -672,55 +733,57 @@ class BaseTrainer:
         self.model.names = self.data["names"]

     def build_targets(self, preds, targets):
-        """Builds target tensors for training YOLO model."""
+        """Build target tensors for training YOLO model."""
         pass

     def progress_string(self):
-        """Returns a string describing training progress."""
+        """Return a string describing training progress."""
         return ""

     # TODO: may need to put these following functions into callback
     def plot_training_samples(self, batch, ni):
-        """Plots training samples during YOLO training."""
+        """Plot training samples during YOLO training."""
         pass

     def plot_training_labels(self):
-        """Plots training labels for YOLO model."""
+        """Plot training labels for YOLO model."""
         pass

     def save_metrics(self, metrics):
         """Save training metrics to a CSV file."""
         keys, vals = list(metrics.keys()), list(metrics.values())
         n = len(metrics) + 2  # number of cols
-        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
         t = time.time() - self.train_time_start
+        self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
+        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time", *keys])).rstrip(",") + "\n")  # header
         with open(self.csv, "a", encoding="utf-8") as f:
-            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
+            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t, *vals])).rstrip(",") + "\n")

     def plot_metrics(self):
-        """Plot and display metrics visually."""
-        pass
+        """Plot metrics from a CSV file."""
+        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png

     def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)."""
+        """Register plots (e.g. to be consumed in callbacks)."""
         path = Path(name)
         self.plots[path] = {"data": data, "timestamp": time.time()}

     def final_eval(self):
         """Perform final evaluation and validation for object detection YOLO model."""
-        ckpt = {}
-        for f in self.last, self.best:
-            if f.exists():
-                if f is self.last:
-                    ckpt = strip_optimizer(f)
-                elif f is self.best:
-                    k = "train_results"  # update best.pt train_metrics from last.pt
-                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
-                LOGGER.info(f"\nValidating {f}...")
-                self.validator.args.plots = self.args.plots
-                self.metrics = self.validator(model=f)
-                self.metrics.pop("fitness", None)
-                self.run_callbacks("on_fit_epoch_end")
+        model = self.best if self.best.exists() else None
+        with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
+            if RANK in {-1, 0}:
+                ckpt = strip_optimizer(self.last) if self.last.exists() else {}
+                if model:
+                    # update best.pt train_metrics from last.pt
+                    strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
+        if model:
+            LOGGER.info(f"\nValidating {model}...")
+            self.validator.args.plots = self.args.plots
+            self.validator.args.compile = False  # disable final val compile as too slow
+            self.metrics = self.validator(model=model)
+            self.metrics.pop("fitness", None)
+            self.run_callbacks("on_fit_epoch_end")

     def check_resume(self, overrides):
         """Check if resume checkpoint exists and update arguments accordingly."""
@@ -731,7 +794,7 @@ class BaseTrainer:
                 last = Path(check_file(resume) if exists else get_latest_run())

                 # Check that resume data YAML exists, otherwise strip to force re-download of dataset
-                ckpt_args = attempt_load_weights(last).args
+                ckpt_args = load_checkpoint(last)[0].args
                 if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
                     ckpt_args["data"] = self.args.data

@@ -754,18 +817,54 @@ class BaseTrainer:
                 ) from e
         self.resume = resume

+    def _load_checkpoint_state(self, ckpt):
+        """Load optimizer, scaler, EMA, and best_fitness from checkpoint."""
+        if ckpt.get("optimizer") is not None:
+            self.optimizer.load_state_dict(ckpt["optimizer"])
+        if ckpt.get("scaler") is not None:
+            self.scaler.load_state_dict(ckpt["scaler"])
+        if self.ema and ckpt.get("ema"):
+            self.ema = ModelEMA(self.model)  # validation with EMA creates inference tensors that can't be updated
+            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
+            self.ema.updates = ckpt["updates"]
+        self.best_fitness = ckpt.get("best_fitness", 0.0)
+
+    def _handle_nan_recovery(self, epoch):
+        """Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint."""
+        loss_nan = self.loss is not None and not self.loss.isfinite()
+        fitness_nan = self.fitness is not None and not np.isfinite(self.fitness)
+        fitness_collapse = self.best_fitness and self.best_fitness > 0 and self.fitness == 0
+        corrupted = RANK in {-1, 0} and loss_nan and (fitness_nan or fitness_collapse)
+        reason = "Loss NaN/Inf" if loss_nan else "Fitness NaN/Inf" if fitness_nan else "Fitness collapse"
+        if RANK != -1:  # DDP: broadcast to all ranks
+            broadcast_list = [corrupted if RANK == 0 else None]
+            dist.broadcast_object_list(broadcast_list, 0)
+            corrupted = broadcast_list[0]
+        if not corrupted:
+            return False
+        if epoch == self.start_epoch or not self.last.exists():
+            LOGGER.warning(f"{reason} detected but can not recover from last.pt...")
+            return False  # Cannot recover on first epoch, let training continue
+        self.nan_recovery_attempts += 1
+        if self.nan_recovery_attempts > 3:
+            raise RuntimeError(f"Training failed: NaN persisted for {self.nan_recovery_attempts} epochs")
+        LOGGER.warning(f"{reason} detected (attempt {self.nan_recovery_attempts}/3), recovering from last.pt...")
+        self._model_train()  # set model to train mode before loading checkpoint to avoid inference tensor errors
+        _, ckpt = load_checkpoint(self.last)
+        ema_state = ckpt["ema"].float().state_dict()
+        if not all(torch.isfinite(v).all() for v in ema_state.values() if isinstance(v, torch.Tensor)):
+            raise RuntimeError(f"Checkpoint {self.last} is corrupted with NaN/Inf weights")
+        unwrap_model(self.model).load_state_dict(ema_state)  # Load EMA weights into model
+        self._load_checkpoint_state(ckpt)  # Load optimizer/scaler/EMA/best_fitness
+        del ckpt, ema_state
+        self.scheduler.last_epoch = epoch - 1
+        return True
+
     def resume_training(self, ckpt):
         """Resume YOLO training from given epoch and best fitness."""
         if ckpt is None or not self.resume:
             return
-        best_fitness = 0.0
         start_epoch = ckpt.get("epoch", -1) + 1
-        if ckpt.get("optimizer", None) is not None:
-            self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
-            best_fitness = ckpt["best_fitness"]
-        if self.ema and ckpt.get("ema"):
-            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
-            self.ema.updates = ckpt["updates"]
         assert start_epoch > 0, (
             f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
             f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
@@ -776,7 +875,7 @@ class BaseTrainer:
                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
             )
             self.epochs += ckpt["epoch"]  # finetune additional epochs
-        self.best_fitness = best_fitness
+        self._load_checkpoint_state(ckpt)
         self.start_epoch = start_epoch
         if start_epoch > (self.epochs - self.args.close_mosaic):
             self._close_dataloader_mosaic()
@@ -790,18 +889,16 @@ class BaseTrainer:
             self.train_loader.dataset.close_mosaic(hyp=copy(self.args))

     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
-        """
-        Construct an optimizer for the given model.
+        """Construct an optimizer for the given model.

         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
-            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
-                based on the number of iterations. Default: 'auto'.
-            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
-            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
-            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
-            iterations (float, optional): The number of iterations, which determines the optimizer if
-                name is 'auto'. Default: 1e5.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
+                number of iterations.
+            lr (float, optional): The learning rate for the optimizer.
+            momentum (float, optional): The momentum factor for the optimizer.
+            decay (float, optional): The weight decay for the optimizer.
+            iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.

         Returns:
             (torch.optim.Optimizer): The constructed optimizer.