ultralytics 8.3.89__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +118 -30
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +5 -5
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +13 -19
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +67 -88
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +21 -18
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +12 -13
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +20 -11
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +22 -11
  53. ultralytics/models/nas/predict.py +9 -4
  54. ultralytics/models/nas/val.py +5 -5
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +18 -15
  57. ultralytics/models/rtdetr/train.py +20 -16
  58. ultralytics/models/rtdetr/val.py +42 -6
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +24 -3
  73. ultralytics/models/yolo/classify/train.py +77 -10
  74. ultralytics/models/yolo/classify/val.py +40 -15
  75. ultralytics/models/yolo/detect/predict.py +23 -10
  76. ultralytics/models/yolo/detect/train.py +85 -15
  77. ultralytics/models/yolo/detect/val.py +145 -21
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +12 -4
  80. ultralytics/models/yolo/obb/train.py +7 -0
  81. ultralytics/models/yolo/obb/val.py +25 -7
  82. ultralytics/models/yolo/pose/predict.py +22 -6
  83. ultralytics/models/yolo/pose/train.py +17 -1
  84. ultralytics/models/yolo/pose/val.py +46 -21
  85. ultralytics/models/yolo/segment/predict.py +22 -8
  86. ultralytics/models/yolo/segment/train.py +6 -0
  87. ultralytics/models/yolo/segment/val.py +100 -14
  88. ultralytics/models/yolo/world/train.py +38 -8
  89. ultralytics/models/yolo/world/train_world.py +39 -10
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +3 -0
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +221 -69
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +32 -27
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +116 -35
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +13 -9
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +112 -45
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +61 -53
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +64 -45
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +181 -33
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +8 -16
  149. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/METADATA +1 -1
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.89.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
@@ -65,28 +65,54 @@ Example:
 
  class BasePredictor:
  """
- BasePredictor.
-
  A base class for creating predictors.
 
+ This class provides the foundation for prediction functionality, handling model setup, inference,
+ and result processing across various input sources.
+
  Attributes:
  args (SimpleNamespace): Configuration for the predictor.
  save_dir (Path): Directory to save results.
  done_warmup (bool): Whether the predictor has finished setup.
- model (nn.Module): Model used for prediction.
+ model (torch.nn.Module): Model used for prediction.
  data (dict): Data configuration.
  device (torch.device): Device used for prediction.
  dataset (Dataset): Dataset used for prediction.
- vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output.
+ vid_writer (dict): Dictionary of {save_path: video_writer} for saving video output.
+ plotted_img (numpy.ndarray): Last plotted image.
+ source_type (SimpleNamespace): Type of input source.
+ seen (int): Number of images processed.
+ windows (List): List of window names for visualization.
+ batch (tuple): Current batch data.
+ results (List): Current batch results.
+ transforms (callable): Image transforms for classification.
+ callbacks (dict): Callback functions for different events.
+ txt_path (Path): Path to save text results.
+ _lock (threading.Lock): Lock for thread-safe inference.
+
+ Methods:
+ preprocess: Prepare input image before inference.
+ inference: Run inference on a given image.
+ postprocess: Process raw predictions into structured results.
+ predict_cli: Run prediction for command line interface.
+ setup_source: Set up input source and inference mode.
+ stream_inference: Stream inference on input source.
+ setup_model: Initialize and configure the model.
+ write_results: Write inference results to files.
+ save_predicted_images: Save prediction visualizations.
+ show: Display results in a window.
+ run_callbacks: Execute registered callbacks for an event.
+ add_callback: Register a new callback function.
  """
 
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
  """
- Initializes the BasePredictor class.
+ Initialize the BasePredictor class.
 
  Args:
- cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
- overrides (dict, optional): Configuration overrides. Defaults to None.
+ cfg (str | dict): Path to a configuration file or a configuration dictionary.
+ overrides (dict | None): Configuration overrides.
+ _callbacks (dict | None): Dictionary of callback functions.
  """
  self.args = get_cfg(cfg, overrides)
  self.save_dir = get_save_dir(self.args)
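For context on how this class is reached in practice, here is a minimal usage sketch through the public YOLO API; the weights file and image URL are illustrative examples, not part of this diff:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")  # a BasePredictor subclass is created internally on first predict call
    results = model("https://ultralytics.com/images/bus.jpg")  # returns a list of Results objects
    for r in results:
        print(r.boxes.xyxy)  # per-image detections as an (N, 4) tensor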
@@ -120,7 +146,7 @@ class BasePredictor:
  Prepares input image before inference.
 
  Args:
- im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
+ im (torch.Tensor | List(np.ndarray)): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
  """
  not_tensor = not isinstance(im, torch.Tensor)
  if not_tensor:
@@ -136,7 +162,7 @@ class BasePredictor:
  return im
 
  def inference(self, im, *args, **kwargs):
- """Runs inference on a given image using the specified model and arguments."""
+ """Run inference on a given image using the specified model and arguments."""
  visualize = (
  increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
  if self.args.visualize and (not self.source_type.tensor)
@@ -149,10 +175,10 @@ class BasePredictor:
  Pre-transform input image before inference.
 
  Args:
- im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
+ im (List[np.ndarray]): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
 
  Returns:
- (list): A list of transformed images.
+ (List[np.ndarray]): A list of transformed images.
  """
  same_shapes = len({x.shape for x in im}) == 1
  letterbox = LetterBox(
@@ -163,11 +189,24 @@ class BasePredictor:
  return [letterbox(image=x) for x in im]
 
  def postprocess(self, preds, img, orig_imgs):
- """Post-processes predictions for an image and returns them."""
+ """Post-process predictions for an image and return them."""
  return preds
 
  def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
- """Performs inference on an image or stream."""
+ """
+ Perform inference on an image or stream.
+
+ Args:
+ source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
+ Source for inference.
+ model (str | Path | torch.nn.Module | None): Model for inference.
+ stream (bool): Whether to stream the inference results. If True, returns a generator.
+ *args (Any): Additional arguments for the inference method.
+ **kwargs (Any): Additional keyword arguments for the inference method.
+
+ Returns:
+ (List[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
+ """
  self.stream = stream
  if stream:
  return self.stream_inference(source, model, *args, **kwargs)
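The new `__call__` docstring above distinguishes list output from generator output; a short sketch of both calling modes follows, with illustrative source paths:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    batch_results = model("path/to/images")  # stream=False: a list of Results, all held in memory
    for r in model("path/to/video.mp4", stream=True):  # stream=True: a generator, one Results at a time
        pass  # consume each frame's result without accumulating them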
@@ -182,6 +221,11 @@ class BasePredictor:
  the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
  generator without storing results.
 
+ Args:
+ source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
+ Source for inference.
+ model (str | Path | torch.nn.Module | None): Model for inference.
+
  Note:
  Do not modify this function or remove the generator. The generator ensures that no outputs are
  accumulated in memory, which is critical for preventing memory issues during long-running predictions.
@@ -191,7 +235,13 @@ class BasePredictor:
  pass
 
  def setup_source(self, source):
- """Sets up source and inference mode."""
+ """
+ Set up source and inference mode.
+
+ Args:
+ source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor):
+ Source for inference.
+ """
  self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
  self.transforms = (
  getattr(
@@ -220,7 +270,19 @@ class BasePredictor:
 
  @smart_inference_mode()
  def stream_inference(self, source=None, model=None, *args, **kwargs):
- """Streams real-time inference on camera feed and saves results to file."""
+ """
+ Stream real-time inference on camera feed and save results to file.
+
+ Args:
+ source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
+ Source for inference.
+ model (str | Path | torch.nn.Module | None): Model for inference.
+ *args (Any): Additional arguments for the inference method.
+ **kwargs (Any): Additional keyword arguments for the inference method.
+
+ Yields:
+ (ultralytics.engine.results.Results): Results objects.
+ """
  if self.args.verbose:
  LOGGER.info("")
 
@@ -306,7 +368,13 @@ class BasePredictor:
  self.run_callbacks("on_predict_end")
 
  def setup_model(self, model, verbose=True):
- """Initialize YOLO model with given parameters and set it to evaluation mode."""
+ """
+ Initialize YOLO model with given parameters and set it to evaluation mode.
+
+ Args:
+ model (str | Path | torch.nn.Module | None): Model to load or use.
+ verbose (bool): Whether to print verbose output.
+ """
  self.model = AutoBackend(
  weights=model or self.args.model,
  device=select_device(self.args.device, verbose=verbose),
@@ -323,7 +391,18 @@ class BasePredictor:
  self.model.eval()
 
  def write_results(self, i, p, im, s):
- """Write inference results to a file or directory."""
+ """
+ Write inference results to a file or directory.
+
+ Args:
+ i (int): Index of the current image in the batch.
+ p (Path): Path to the current image.
+ im (torch.Tensor): Preprocessed image tensor.
+ s (List[str]): List of result strings.
+
+ Returns:
+ (str): String with result information.
+ """
  string = "" # print string
  if len(im.shape) == 3:
  im = im[None] # expand for batch dim
@@ -363,7 +442,13 @@ class BasePredictor:
  return string
 
  def save_predicted_images(self, save_path="", frame=0):
- """Save video predictions as mp4 at specified path."""
+ """
+ Save video predictions as mp4 or images as jpg at specified path.
+
+ Args:
+ save_path (str): Path to save the results.
+ frame (int): Frame number for video mode.
+ """
  im = self.plotted_img
 
  # Save videos and streams
@@ -391,7 +476,7 @@ class BasePredictor:
  cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im) # save to JPG for best support
 
  def show(self, p=""):
- """Display an image in a window using the OpenCV imshow function."""
+ """Display an image in a window."""
  im = self.plotted_img
  if platform.system() == "Linux" and p not in self.windows:
  self.windows.append(p)
@@ -401,10 +486,10 @@ class BasePredictor:
  cv2.waitKey(300 if self.dataset.mode == "image" else 1) # 1 millisecond
 
  def run_callbacks(self, event: str):
- """Runs all registered callbacks for a specific event."""
+ """Run all registered callbacks for a specific event."""
  for callback in self.callbacks.get(event, []):
  callback(self)
 
  def add_callback(self, event: str, func):
- """Add callback."""
+ """Add a callback function for a specific event."""
  self.callbacks[event].append(func)
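A small sketch of the callback hooks documented above, using the public add_callback API; the event name and handler are ordinary examples, and `predictor.seen` is the attribute listed in the class docstring:

    from ultralytics import YOLO

    def on_batch_end(predictor):
        print(f"{predictor.seen} images processed")  # callbacks receive the BasePredictor instance

    model = YOLO("yolo11n.pt")
    model.add_callback("on_predict_batch_end", on_batch_end)
    model.predict("https://ultralytics.com/images/bus.jpg", save=False)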
@@ -87,17 +87,20 @@ class BaseTrainer:
  fitness (float): Current fitness value.
  loss (float): Current loss value.
  tloss (float): Total loss value.
- loss_names (list): List of loss names.
+ loss_names (List): List of loss names.
  csv (Path): Path to results CSV file.
+ metrics (Dict): Dictionary of metrics.
+ plots (Dict): Dictionary of plots.
  """
 
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
  """
- Initializes the BaseTrainer class.
+ Initialize the BaseTrainer class.
 
  Args:
  cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
- overrides (dict, optional): Configuration overrides. Defaults to None.
+ overrides (Dict, optional): Configuration overrides. Defaults to None.
+ _callbacks (List, optional): List of callback functions. Defaults to None.
  """
  self.args = get_cfg(cfg, overrides)
  self.check_resume(overrides)
@@ -156,11 +159,11 @@ class BaseTrainer:
  callbacks.add_integration_callbacks(self)
 
  def add_callback(self, event: str, callback):
- """Appends the given callback."""
+ """Append the given callback to the event's callback list."""
  self.callbacks[event].append(callback)
 
  def set_callback(self, event: str, callback):
- """Overrides the existing callbacks with the given callback."""
+ """Override the existing callbacks with the given callback for the specified event."""
  self.callbacks[event] = [callback]
 
  def run_callbacks(self, event: str):
@@ -216,7 +219,7 @@ class BaseTrainer:
  self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
 
  def _setup_ddp(self, world_size):
- """Initializes and sets the DistributedDataParallel parameters for training."""
+ """Initialize and set the DistributedDataParallel parameters for training."""
  torch.cuda.set_device(RANK)
  self.device = torch.device("cuda", RANK)
  # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
@@ -229,7 +232,7 @@ class BaseTrainer:
  )
 
  def _setup_train(self, world_size):
- """Builds dataloaders and optimizer on correct rank process."""
+ """Build dataloaders and optimizer on correct rank process."""
  # Model
  self.run_callbacks("on_pretrain_routine_start")
  ckpt = self.setup_model()
@@ -317,7 +320,7 @@ class BaseTrainer:
  self.run_callbacks("on_pretrain_routine_end")
 
  def _do_train(self, world_size=1):
- """Train completed, evaluate and plot if specified by arguments."""
+ """Train the model with the specified world size."""
  if world_size > 1:
  self._setup_ddp(world_size)
  self._setup_train(world_size)
@@ -477,7 +480,7 @@ class BaseTrainer:
  self.run_callbacks("teardown")
 
  def auto_batch(self, max_num_obj=0):
- """Get batch size by calculating memory occupation of model."""
+ """Calculate optimal batch size based on model and device memory constraints."""
  return check_train_batch_size(
  model=self.model,
  imgsz=self.args.imgsz,
@@ -487,12 +490,12 @@ class BaseTrainer:
  ) # returns batch size
 
  def _get_memory(self, fraction=False):
- """Get accelerator memory utilization in GB or fraction."""
+ """Get accelerator memory utilization in GB or as a fraction of total memory."""
  memory, total = 0, 0
  if self.device.type == "mps":
  memory = torch.mps.driver_allocated_memory()
  if fraction:
- total = torch.mps.get_mem_info()[0]
+ return __import__("psutil").virtual_memory().percent / 100
  elif self.device.type == "cpu":
  pass
  else:
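For reference, the memory-fraction idea behind `_get_memory` can be approximated on a CUDA device with plain PyTorch calls; this standalone sketch is illustrative and not the trainer's exact code path:

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        reserved = torch.cuda.memory_reserved(device)  # bytes held by the caching allocator
        total = torch.cuda.get_device_properties(device).total_memory
        print(f"{reserved / 2**30:.2f} GiB reserved ({reserved / total:.1%} of device memory)")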
@@ -502,7 +505,7 @@ class BaseTrainer:
  return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
 
  def _clear_memory(self):
- """Clear accelerator memory on different platforms."""
+ """Clear accelerator memory by calling garbage collector and emptying cache."""
  gc.collect()
  if self.device.type == "mps":
  torch.mps.empty_cache()
@@ -512,7 +515,7 @@ class BaseTrainer:
  torch.cuda.empty_cache()
 
  def read_results_csv(self):
- """Read results.csv into a dict using pandas."""
+ """Read results.csv into a dictionary using pandas."""
  import pandas as pd # scope for faster 'import ultralytics'
 
  return pd.read_csv(self.csv).to_dict(orient="list")
@@ -554,9 +557,10 @@ class BaseTrainer:
 
  def get_dataset(self):
  """
- Get train, val path from data dict if it exists.
+ Get train and validation datasets from data dictionary.
 
- Returns None if data format is not recognized.
+ Returns:
+ (tuple): A tuple containing the training and validation/test datasets.
  """
  try:
  if self.args.task == "classify":
@@ -580,7 +584,12 @@ class BaseTrainer:
  return data["train"], data.get("val") or data.get("test")
 
  def setup_model(self):
- """Load/create/download model for any task."""
+ """
+ Load, create, or download model for any task.
+
+ Returns:
+ (dict): Optional checkpoint to resume training from.
+ """
  if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
  return
 
@@ -610,9 +619,10 @@ class BaseTrainer:
 
  def validate(self):
  """
- Runs validation on test set using self.validator.
+ Run validation on test set using self.validator.
 
- The returned dict is expected to contain "fitness" key.
+ Returns:
+ (tuple): A tuple containing metrics dictionary and fitness score.
  """
  metrics = self.validator(self)
  fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
@@ -646,7 +656,7 @@ class BaseTrainer:
  return {"loss": loss_items} if loss_items is not None else ["loss"]
 
  def set_model_attributes(self):
- """To set or update model parameters before training."""
+ """Set or update model parameters before training."""
  self.model.names = self.data["names"]
 
  def build_targets(self, preds, targets):
@@ -667,7 +677,7 @@ class BaseTrainer:
  pass
 
  def save_metrics(self, metrics):
- """Saves training metrics to a CSV file."""
+ """Save training metrics to a CSV file."""
  keys, vals = list(metrics.keys()), list(metrics.values())
  n = len(metrics) + 2 # number of cols
  s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
@@ -685,7 +695,7 @@ class BaseTrainer:
  self.plots[path] = {"data": data, "timestamp": time.time()}
 
  def final_eval(self):
- """Performs final evaluation and validation for object detection YOLO model."""
+ """Perform final evaluation and validation for object detection YOLO model."""
  ckpt = {}
  for f in self.last, self.best:
  if f.exists():
@@ -769,8 +779,7 @@ class BaseTrainer:
 
  def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
  """
- Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
- weight decay, and number of iterations.
+ Construct an optimizer for the given model.
 
  Args:
  model (torch.nn.Module): The model for which to build an optimizer.
29
29
 
30
30
  class Tuner:
31
31
  """
32
- Class responsible for hyperparameter tuning of YOLO models.
32
+ A class for hyperparameter tuning of YOLO models.
33
33
 
34
- The class evolves YOLO model hyperparameters over a given number of iterations
35
- by mutating them according to the search space and retraining the model to evaluate their performance.
34
+ The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
35
+ search space and retraining the model to evaluate their performance.
36
36
 
37
37
  Attributes:
38
- space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
38
+ space (Dict): Hyperparameter search space containing bounds and scaling factors for mutation.
39
39
  tune_dir (Path): Directory where evolution logs and results will be saved.
40
40
  tune_csv (Path): Path to the CSV file where evolution logs are saved.
41
+ args (Dict): Configuration arguments for the tuning process.
42
+ callbacks (List): Callback functions to be executed during tuning.
43
+ prefix (str): Prefix string for logging messages.
41
44
 
42
45
  Methods:
43
- _mutate(hyp: dict) -> dict:
44
- Mutates the given hyperparameters within the bounds specified in `self.space`.
45
-
46
- __call__():
47
- Executes the hyperparameter evolution across multiple iterations.
46
+ _mutate: Mutates the given hyperparameters within the specified bounds.
47
+ __call__: Executes the hyperparameter evolution across multiple iterations.
48
48
 
49
49
  Examples:
50
50
  Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
@@ -53,6 +53,7 @@ class Tuner:
53
53
  >>> model.tune(
54
54
  ... data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False
55
55
  ... )
56
+
56
57
  Tune with custom search space.
57
58
  >>> model.tune(space={key1: val1, key2: val2}) # custom search space dictionary
58
59
  """
@@ -62,7 +63,8 @@ class Tuner:
62
63
  Initialize the Tuner with configurations.
63
64
 
64
65
  Args:
65
- args (dict, optional): Configuration for hyperparameter evolution.
66
+ args (Dict): Configuration for hyperparameter evolution.
67
+ _callbacks (List, optional): Callback functions to be executed during tuning.
66
68
  """
67
69
  self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
68
70
  # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
@@ -104,7 +106,7 @@ class Tuner:
104
106
 
105
107
  def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
106
108
  """
107
- Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.
109
+ Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
108
110
 
109
111
  Args:
110
112
  parent (str): Parent selection method: 'single' or 'weighted'.
@@ -113,7 +115,7 @@ class Tuner:
113
115
  sigma (float): Standard deviation for Gaussian random number generator.
114
116
 
115
117
  Returns:
116
- (dict): A dictionary containing mutated hyperparameters.
118
+ (Dict): A dictionary containing mutated hyperparameters.
117
119
  """
118
120
  if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate
119
121
  # Select parent(s)
@@ -150,22 +152,23 @@ class Tuner:
150
152
 
151
153
  def __call__(self, model=None, iterations=10, cleanup=True):
152
154
  """
153
- Executes the hyperparameter evolution process when the Tuner instance is called.
155
+ Execute the hyperparameter evolution process when the Tuner instance is called.
154
156
 
155
157
  This method iterates through the number of iterations, performing the following steps in each iteration:
158
+
156
159
  1. Load the existing hyperparameters or initialize new ones.
157
160
  2. Mutate the hyperparameters using the `mutate` method.
158
161
  3. Train a YOLO model with the mutated hyperparameters.
159
162
  4. Log the fitness score and mutated hyperparameters to a CSV file.
160
163
 
161
164
  Args:
162
- model (Model): A pre-initialized YOLO model to be used for training.
163
- iterations (int): The number of generations to run the evolution for.
164
- cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
165
+ model (Model): A pre-initialized YOLO model to be used for training.
166
+ iterations (int): The number of generations to run the evolution for.
167
+ cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
165
168
 
166
169
  Note:
167
- The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
168
- Ensure this path is set correctly in the Tuner instance.
170
+ The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
171
+ Ensure this path is set correctly in the Tuner instance.
169
172
  """
170
173
  t0 = time.time()
171
174
  best_save_dir, best_metrics = None, None