ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/engine/validator.py CHANGED
@@ -41,43 +41,66 @@ from ultralytics.utils.torch_utils import de_parallel, select_device, smart_infe
41
41
 
42
42
  class BaseValidator:
43
43
  """
44
- BaseValidator.
45
-
46
44
  A base class for creating validators.
47
45
 
46
+ This class provides the foundation for validation processes, including model evaluation, metric computation, and
47
+ result visualization.
48
+
48
49
  Attributes:
49
50
  args (SimpleNamespace): Configuration for the validator.
50
51
  dataloader (DataLoader): Dataloader to use for validation.
51
52
  pbar (tqdm): Progress bar to update during validation.
52
53
  model (nn.Module): Model to validate.
53
- data (dict): Data dictionary.
54
+ data (Dict): Data dictionary containing dataset information.
54
55
  device (torch.device): Device to use for validation.
55
56
  batch_i (int): Current batch index.
56
57
  training (bool): Whether the model is in training mode.
57
- names (dict): Class names.
58
- seen: Records the number of images seen so far during validation.
59
- stats: Placeholder for statistics during validation.
60
- confusion_matrix: Placeholder for a confusion matrix.
61
- nc: Number of classes.
62
- iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
63
- jdict (dict): Dictionary to store JSON validation results.
64
- speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
65
- batch processing times in milliseconds.
58
+ names (Dict): Class names mapping.
59
+ seen (int): Number of images seen so far during validation.
60
+ stats (Dict): Statistics collected during validation.
61
+ confusion_matrix: Confusion matrix for classification evaluation.
62
+ nc (int): Number of classes.
63
+ iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
64
+ jdict (List): List to store JSON validation results.
65
+ speed (Dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
66
+ batch processing times in milliseconds.
66
67
  save_dir (Path): Directory to save results.
67
- plots (dict): Dictionary to store plots for visualization.
68
- callbacks (dict): Dictionary to store various callback functions.
68
+ plots (Dict): Dictionary to store plots for visualization.
69
+ callbacks (Dict): Dictionary to store various callback functions.
70
+
71
+ Methods:
72
+ __call__: Execute validation process, running inference on dataloader and computing performance metrics.
73
+ match_predictions: Match predictions to ground truth objects using IoU.
74
+ add_callback: Append the given callback to the specified event.
75
+ run_callbacks: Run all callbacks associated with a specified event.
76
+ get_dataloader: Get data loader from dataset path and batch size.
77
+ build_dataset: Build dataset from image path.
78
+ preprocess: Preprocess an input batch.
79
+ postprocess: Postprocess the predictions.
80
+ init_metrics: Initialize performance metrics for the YOLO model.
81
+ update_metrics: Update metrics based on predictions and batch.
82
+ finalize_metrics: Finalize and return all metrics.
83
+ get_stats: Return statistics about the model's performance.
84
+ check_stats: Check statistics.
85
+ print_results: Print the results of the model's predictions.
86
+ get_desc: Get description of the YOLO model.
87
+ on_plot: Register plots (e.g. to be consumed in callbacks).
88
+ plot_val_samples: Plot validation samples during training.
89
+ plot_predictions: Plot YOLO model predictions on batch images.
90
+ pred_to_json: Convert predictions to JSON format.
91
+ eval_json: Evaluate and return JSON format of prediction statistics.
69
92
  """
70
93
 
71
94
  def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
72
95
  """
73
- Initializes a BaseValidator instance.
96
+ Initialize a BaseValidator instance.
74
97
 
75
98
  Args:
76
- dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
99
+ dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
77
100
  save_dir (Path, optional): Directory to save results.
78
- pbar (tqdm.tqdm): Progress bar for displaying progress.
79
- args (SimpleNamespace): Configuration for the validator.
80
- _callbacks (dict): Dictionary to store various callback functions.
101
+ pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
102
+ args (SimpleNamespace, optional): Configuration for the validator.
103
+ _callbacks (Dict, optional): Dictionary to store various callback functions.
81
104
  """
82
105
  self.args = get_cfg(overrides=args)
83
106
  self.dataloader = dataloader
@@ -107,13 +130,22 @@ class BaseValidator:
107
130
 
108
131
  @smart_inference_mode()
109
132
  def __call__(self, trainer=None, model=None):
110
- """Executes validation process, running inference on dataloader and computing performance metrics."""
133
+ """
134
+ Execute validation process, running inference on dataloader and computing performance metrics.
135
+
136
+ Args:
137
+ trainer (object, optional): Trainer object that contains the model to validate.
138
+ model (nn.Module, optional): Model to validate if not using a trainer.
139
+
140
+ Returns:
141
+ stats (dict): Dictionary containing validation statistics.
142
+ """
111
143
  self.training = trainer is not None
112
144
  augment = self.args.augment and (not self.training)
113
145
  if self.training:
114
146
  self.device = trainer.device
115
147
  self.data = trainer.data
116
- # force FP16 val during training
148
+ # Force FP16 val during training
117
149
  self.args.half = self.device.type != "cpu" and trainer.amp
118
150
  model = trainer.ema.ema or trainer.model
119
151
  model = model.half() if self.args.half else model.float()
@@ -221,18 +253,20 @@ class BaseValidator:
221
253
  LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
222
254
  return stats
223
255
 
224
- def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
256
+ def match_predictions(
257
+ self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
258
+ ) -> torch.Tensor:
225
259
  """
226
- Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
260
+ Match predictions to ground truth objects using IoU.
227
261
 
228
262
  Args:
229
- pred_classes (torch.Tensor): Predicted class indices of shape(N,).
230
- true_classes (torch.Tensor): Target class indices of shape(M,).
231
- iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
263
+ pred_classes (torch.Tensor): Predicted class indices of shape (N,).
264
+ true_classes (torch.Tensor): Target class indices of shape (M,).
265
+ iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
232
266
  use_scipy (bool): Whether to use scipy for matching (more precise).
233
267
 
234
268
  Returns:
235
- (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
269
+ (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
236
270
  """
237
271
  # Dx10 matrix, where D - detections, 10 - IoU thresholds
238
272
  correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
@@ -264,11 +298,11 @@ class BaseValidator:
264
298
  return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
265
299
 
266
300
  def add_callback(self, event: str, callback):
267
- """Appends the given callback."""
301
+ """Append the given callback to the specified event."""
268
302
  self.callbacks[event].append(callback)
269
303
 
270
304
  def run_callbacks(self, event: str):
271
- """Runs all callbacks associated with a specified event."""
305
+ """Run all callbacks associated with a specified event."""
272
306
  for callback in self.callbacks.get(event, []):
273
307
  callback(self)
274
308
 
@@ -277,15 +311,15 @@ class BaseValidator:
277
311
  raise NotImplementedError("get_dataloader function not implemented for this validator")
278
312
 
279
313
  def build_dataset(self, img_path):
280
- """Build dataset."""
314
+ """Build dataset from image path."""
281
315
  raise NotImplementedError("build_dataset function not implemented in validator")
282
316
 
283
317
  def preprocess(self, batch):
284
- """Preprocesses an input batch."""
318
+ """Preprocess an input batch."""
285
319
  return batch
286
320
 
287
321
  def postprocess(self, preds):
288
- """Preprocesses the predictions."""
322
+ """Postprocess the predictions."""
289
323
  return preds
290
324
 
291
325
  def init_metrics(self, model):
@@ -293,23 +327,23 @@ class BaseValidator:
293
327
  pass
294
328
 
295
329
  def update_metrics(self, preds, batch):
296
- """Updates metrics based on predictions and batch."""
330
+ """Update metrics based on predictions and batch."""
297
331
  pass
298
332
 
299
333
  def finalize_metrics(self, *args, **kwargs):
300
- """Finalizes and returns all metrics."""
334
+ """Finalize and return all metrics."""
301
335
  pass
302
336
 
303
337
  def get_stats(self):
304
- """Returns statistics about the model's performance."""
338
+ """Return statistics about the model's performance."""
305
339
  return {}
306
340
 
307
341
  def check_stats(self, stats):
308
- """Checks statistics."""
342
+ """Check statistics."""
309
343
  pass
310
344
 
311
345
  def print_results(self):
312
- """Prints the results of the model's predictions."""
346
+ """Print the results of the model's predictions."""
313
347
  pass
314
348
 
315
349
  def get_desc(self):
@@ -318,20 +352,20 @@ class BaseValidator:
318
352
 
319
353
  @property
320
354
  def metric_keys(self):
321
- """Returns the metric keys used in YOLO training/validation."""
355
+ """Return the metric keys used in YOLO training/validation."""
322
356
  return []
323
357
 
324
358
  def on_plot(self, name, data=None):
325
- """Registers plots (e.g. to be consumed in callbacks)."""
359
+ """Register plots (e.g. to be consumed in callbacks)."""
326
360
  self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
327
361
 
328
362
  # TODO: may need to put these following functions into callback
329
363
  def plot_val_samples(self, batch, ni):
330
- """Plots validation samples during training."""
364
+ """Plot validation samples during training."""
331
365
  pass
332
366
 
333
367
  def plot_predictions(self, batch, preds, ni):
334
- """Plots YOLO model predictions on batch images."""
368
+ """Plot YOLO model predictions on batch images."""
335
369
  pass
336
370
 
337
371
  def pred_to_json(self, preds, batch):
ultralytics/hub/__init__.py CHANGED
@@ -23,7 +23,7 @@ __all__ = (
23
23
  )
24
24
 
25
25
 
26
- def login(api_key: str = None, save=True) -> bool:
26
+ def login(api_key: str = None, save: bool = True) -> bool:
27
27
  """
28
28
  Log in to the Ultralytics HUB API using the provided API key.
29
29
 
@@ -31,8 +31,8 @@ def login(api_key: str = None, save=True) -> bool:
31
31
  environment variable if successfully authenticated.
32
32
 
33
33
  Args:
34
- api_key (str, optional): API key to use for authentication.
35
- If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
34
+ api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from SETTINGS
35
+ or HUB_API_KEY environment variable.
36
36
  save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
37
37
 
38
38
  Returns:
@@ -79,7 +79,7 @@ def logout():
79
79
  LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")
80
80
 
81
81
 
82
- def reset_model(model_id=""):
82
+ def reset_model(model_id: str = ""):
83
83
  """Reset a trained model to an untrained state."""
84
84
  r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
85
85
  if r.status_code == 200:
@@ -95,8 +95,8 @@ def export_fmts_hub():
95
95
  return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
96
96
 
97
97
 
98
- def export_model(model_id="", format="torchscript"):
99
- """Export a model to all formats."""
98
+ def export_model(model_id: str = "", format: str = "torchscript"):
99
+ """Export a model to the specified format."""
100
100
  assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
101
101
  r = requests.post(
102
102
  f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
@@ -105,7 +105,7 @@ def export_model(model_id="", format="torchscript"):
105
105
  LOGGER.info(f"{PREFIX}{format} export started ✅")
106
106
 
107
107
 
108
- def get_export(model_id="", format="torchscript"):
108
+ def get_export(model_id: str = "", format: str = "torchscript"):
109
109
  """Get an exported model dictionary with download URL."""
110
110
  assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
111
111
  r = requests.post(
@@ -119,17 +119,12 @@ def get_export(model_id="", format="torchscript"):
119
119
 
120
120
  def check_dataset(path: str, task: str) -> None:
121
121
  """
122
- Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
123
- to the HUB. Usage examples are given below.
122
+ Check HUB dataset Zip file for errors before upload.
124
123
 
125
124
  Args:
126
125
  path (str): Path to data.zip (with data.yaml inside data.zip).
127
126
  task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.
128
127
 
129
- Note:
130
- Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
131
- i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
132
-
133
128
  Examples:
134
129
  >>> from ultralytics.hub import check_dataset
135
130
  >>> check_dataset("path/to/coco8.zip", task="detect") # detect dataset
@@ -137,6 +132,10 @@ def check_dataset(path: str, task: str) -> None:
137
132
  >>> check_dataset("path/to/coco8-pose.zip", task="pose") # pose dataset
138
133
  >>> check_dataset("path/to/dota8.zip", task="obb") # OBB dataset
139
134
  >>> check_dataset("path/to/imagenet10.zip", task="classify") # classification dataset
135
+
136
+ Note:
137
+ Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
138
+ i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
140
139
  """
141
140
  HUBDatasetStats(path=path, task=task).get_json()
142
141
  LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
ultralytics/hub/auth.py CHANGED
@@ -18,14 +18,14 @@ class Auth:
18
18
  3. Prompting the user to enter an API key.
19
19
 
20
20
  Attributes:
21
- id_token (str or bool): Token used for identity verification, initialized as False.
22
- api_key (str or bool): API key for authentication, initialized as False.
21
+ id_token (str | bool): Token used for identity verification, initialized as False.
22
+ api_key (str | bool): API key for authentication, initialized as False.
23
23
  model_key (bool): Placeholder for model key, initialized as False.
24
24
  """
25
25
 
26
26
  id_token = api_key = model_key = False
27
27
 
28
- def __init__(self, api_key="", verbose=False):
28
+ def __init__(self, api_key: str = "", verbose: bool = False):
29
29
  """
30
30
  Initialize Auth class and authenticate user.
31
31
 
@@ -70,12 +70,8 @@ class Auth:
70
70
  elif verbose:
71
71
  LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")
72
72
 
73
- def request_api_key(self, max_attempts=3):
74
- """
75
- Prompt the user to input their API key.
76
-
77
- Returns the model ID.
78
- """
73
+ def request_api_key(self, max_attempts: int = 3) -> bool:
74
+ """Prompt the user to input their API key."""
79
75
  import getpass
80
76
 
81
77
  for attempts in range(max_attempts):
@@ -107,8 +103,9 @@ class Auth:
107
103
 
108
104
  def auth_with_cookies(self) -> bool:
109
105
  """
110
- Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a
111
- supported browser.
106
+ Attempt to fetch authentication via cookies and set id_token.
107
+
108
+ User must be logged in to HUB and running in a supported browser.
112
109
 
113
110
  Returns:
114
111
  (bool): True if authentication is successful, False otherwise.
@@ -131,7 +128,7 @@ class Auth:
131
128
  Get the authentication header for making API requests.
132
129
 
133
130
  Returns:
134
- (dict): The authentication header if id_token or API key is set, None otherwise.
131
+ (dict | None): The authentication header if id_token or API key is set, None otherwise.
135
132
  """
136
133
  if self.id_token:
137
134
  return {"authorization": f"Bearer {self.id_token}"}
ultralytics/hub/session.py CHANGED
@@ -20,13 +20,25 @@ class HUBTrainingSession:
20
20
  """
21
21
  HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
22
22
 
23
+ This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
24
+ model creation, metrics tracking, and checkpoint uploading.
25
+
23
26
  Attributes:
24
27
  model_id (str): Identifier for the YOLO model being trained.
25
28
  model_url (str): URL for the model in Ultralytics HUB.
26
- rate_limits (dict): Rate limits for different API calls (in seconds).
27
- timers (dict): Timers for rate limiting.
28
- metrics_queue (dict): Queue for the model's metrics.
29
- model (dict): Model data fetched from Ultralytics HUB.
29
+ rate_limits (Dict): Rate limits for different API calls (in seconds).
30
+ timers (Dict): Timers for rate limiting.
31
+ metrics_queue (Dict): Queue for the model's metrics.
32
+ metrics_upload_failed_queue (Dict): Queue for metrics that failed to upload.
33
+ model (Dict): Model data fetched from Ultralytics HUB.
34
+ model_file (str): Path to the model file.
35
+ train_args (Dict): Arguments for training the model.
36
+ client (HUBClient): Client for interacting with Ultralytics HUB.
37
+ filename (str): Filename of the model.
38
+
39
+ Examples:
40
+ >>> session = HUBTrainingSession("https://hub.ultralytics.com/models/example-model")
41
+ >>> session.upload_metrics()
30
42
  """
31
43
 
32
44
  def __init__(self, identifier):
@@ -78,7 +90,16 @@ class HUBTrainingSession:
78
90
 
79
91
  @classmethod
80
92
  def create_session(cls, identifier, args=None):
81
- """Class method to create an authenticated HUBTrainingSession or return None."""
93
+ """
94
+ Create an authenticated HUBTrainingSession or return None.
95
+
96
+ Args:
97
+ identifier (str): Model identifier used to initialize the HUB training session.
98
+ args (Dict, optional): Arguments for creating a new model if identifier is not a HUB model URL.
99
+
100
+ Returns:
101
+ (HUBTrainingSession | None): An authenticated session or None if creation fails.
102
+ """
82
103
  try:
83
104
  session = cls(identifier)
84
105
  if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
@@ -90,7 +111,15 @@ class HUBTrainingSession:
90
111
  return None
91
112
 
92
113
  def load_model(self, model_id):
93
- """Loads an existing model from Ultralytics HUB using the provided model identifier."""
114
+ """
115
+ Load an existing model from Ultralytics HUB using the provided model identifier.
116
+
117
+ Args:
118
+ model_id (str): The identifier of the model to load.
119
+
120
+ Raises:
121
+ ValueError: If the specified HUB model does not exist.
122
+ """
94
123
  self.model = self.client.model(model_id)
95
124
  if not self.model.data: # then model does not exist
96
125
  raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling
@@ -108,7 +137,15 @@ class HUBTrainingSession:
108
137
  LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
109
138
 
110
139
  def create_model(self, model_args):
111
- """Initializes a HUB training session with the specified model identifier."""
140
+ """
141
+ Initialize a HUB training session with the specified model arguments.
142
+
143
+ Args:
144
+ model_args (Dict): Arguments for creating the model, including batch size, epochs, image size, etc.
145
+
146
+ Returns:
147
+ (None): If the model could not be created.
148
+ """
112
149
  payload = {
113
150
  "config": {
114
151
  "batchSize": model_args.get("batch", -1),
@@ -146,7 +183,7 @@ class HUBTrainingSession:
146
183
  @staticmethod
147
184
  def _parse_identifier(identifier):
148
185
  """
149
- Parses the given identifier to determine the type of identifier and extract relevant components.
186
+ Parse the given identifier to determine the type and extract relevant components.
150
187
 
151
188
  The method supports different identifier formats:
152
189
  - A HUB model URL https://hub.ultralytics.com/models/MODEL
@@ -176,7 +213,7 @@ class HUBTrainingSession:
176
213
 
177
214
  def _set_train_args(self):
178
215
  """
179
- Initializes training arguments and creates a model entry on the Ultralytics HUB.
216
+ Initialize training arguments and create a model entry on the Ultralytics HUB.
180
217
 
181
218
  This method sets up training arguments based on the model's state and updates them with any additional
182
219
  arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
@@ -218,10 +255,26 @@ class HUBTrainingSession:
218
255
  *args,
219
256
  **kwargs,
220
257
  ):
221
- """Attempts to execute `request_func` with retries, timeout handling, optional threading, and progress."""
258
+ """
259
+ Attempt to execute `request_func` with retries, timeout handling, optional threading, and progress tracking.
260
+
261
+ Args:
262
+ request_func (callable): The function to execute.
263
+ retry (int): Number of retry attempts.
264
+ timeout (int): Maximum time to wait for the request to complete.
265
+ thread (bool): Whether to run the request in a separate thread.
266
+ verbose (bool): Whether to log detailed messages.
267
+ progress_total (int, optional): Total size for progress tracking.
268
+ stream_response (bool, optional): Whether to stream the response.
269
+ *args (Any): Additional positional arguments for request_func.
270
+ **kwargs (Any): Additional keyword arguments for request_func.
271
+
272
+ Returns:
273
+ (requests.Response | None): The response object if thread=False, otherwise None.
274
+ """
222
275
 
223
276
  def retry_request():
224
- """Attempts to call `request_func` with retries, timeout, and optional threading."""
277
+ """Attempt to call `request_func` with retries, timeout, and optional threading."""
225
278
  t0 = time.time() # Record the start time for the timeout
226
279
  response = None
227
280
  for i in range(retry + 1):
@@ -274,7 +327,15 @@ class HUBTrainingSession:
274
327
 
275
328
  @staticmethod
276
329
  def _should_retry(status_code):
277
- """Determines if a request should be retried based on the HTTP status code."""
330
+ """
331
+ Determine if a request should be retried based on the HTTP status code.
332
+
333
+ Args:
334
+ status_code (int): The HTTP status code from the response.
335
+
336
+ Returns:
337
+ (bool): True if the request should be retried, False otherwise.
338
+ """
278
339
  retry_codes = {
279
340
  HTTPStatus.REQUEST_TIMEOUT,
280
341
  HTTPStatus.BAD_GATEWAY,
@@ -287,9 +348,9 @@ class HUBTrainingSession:
287
348
  Generate a retry message based on the response status code.
288
349
 
289
350
  Args:
290
- response: The HTTP response object.
291
- retry: The number of retry attempts allowed.
292
- timeout: The maximum timeout duration.
351
+ response (requests.Response): The HTTP response object.
352
+ retry (int): The number of retry attempts allowed.
353
+ timeout (int): The maximum timeout duration.
293
354
 
294
355
  Returns:
295
356
  (str): The retry message.
@@ -367,9 +428,6 @@ class HUBTrainingSession:
367
428
  Args:
368
429
  content_length (int): The total size of the content to be downloaded in bytes.
369
430
  response (requests.Response): The response object from the file download request.
370
-
371
- Returns:
372
- None
373
431
  """
374
432
  with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
375
433
  for data in response.iter_content(chunk_size=1024):
@@ -382,9 +440,6 @@ class HUBTrainingSession:
382
440
 
383
441
  Args:
384
442
  response (requests.Response): The response object from the file download request.
385
-
386
- Returns:
387
- None
388
443
  """
389
444
  for _ in response.iter_content(chunk_size=1024):
390
445
  pass # Do nothing with data chunks
ultralytics/hub/utils.py CHANGED
@@ -43,7 +43,7 @@ def request_with_credentials(url: str) -> any:
43
43
  url (str): The URL to make the request to.
44
44
 
45
45
  Returns:
46
- (any): The response data from the AJAX request.
46
+ (Any): The response data from the AJAX request.
47
47
 
48
48
  Raises:
49
49
  OSError: If the function is not run in a Google Colab environment.
@@ -83,14 +83,14 @@ def requests_with_progress(method, url, **kwargs):
83
83
  Args:
84
84
  method (str): The HTTP method to use (e.g. 'GET', 'POST').
85
85
  url (str): The URL to send the request to.
86
- **kwargs (any): Additional keyword arguments to pass to the underlying `requests.request` function.
86
+ **kwargs (Any): Additional keyword arguments to pass to the underlying `requests.request` function.
87
87
 
88
88
  Returns:
89
89
  (requests.Response): The response object from the HTTP request.
90
90
 
91
- Note:
91
+ Notes:
92
92
  - If 'progress' is set to True, the progress bar will display the download progress for responses with a known
93
- content length.
93
+ content length.
94
94
  - If 'progress' is a number then progress bar will display assuming content length = progress.
95
95
  """
96
96
  progress = kwargs.pop("progress", False)
@@ -110,18 +110,18 @@ def requests_with_progress(method, url, **kwargs):
110
110
 
111
111
  def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
112
112
  """
113
- Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
113
+ Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
114
114
 
115
115
  Args:
116
116
  method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
117
117
  url (str): The URL to make the request to.
118
- retry (int, optional): Number of retries to attempt before giving up. Default is 3.
119
- timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
120
- thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
121
- code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
122
- verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
123
- progress (bool, optional): Whether to show a progress bar during the request. Default is False.
124
- **kwargs (any): Keyword arguments to be passed to the requests function specified in method.
118
+ retry (int, optional): Number of retries to attempt before giving up.
119
+ timeout (int, optional): Timeout in seconds after which the function will give up retrying.
120
+ thread (bool, optional): Whether to execute the request in a separate daemon thread.
121
+ code (int, optional): An identifier for the request, used for logging purposes.
122
+ verbose (bool, optional): A flag to determine whether to print out to console or not.
123
+ progress (bool, optional): Whether to show a progress bar during the request.
124
+ **kwargs (Any): Keyword arguments to be passed to the requests function specified in method.
125
125
 
126
126
  Returns:
127
127
  (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
@@ -169,20 +169,22 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
169
169
 
170
170
  class Events:
171
171
  """
172
- A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
173
- disabled when sync=False. Run 'yolo settings' to see and update settings.
172
+ A class for collecting anonymous event analytics.
173
+
174
+ Event analytics are enabled when sync=True in settings and disabled when sync=False. Run 'yolo settings' to see and
175
+ update settings.
174
176
 
175
177
  Attributes:
176
178
  url (str): The URL to send anonymous events.
177
179
  rate_limit (float): The rate limit in seconds for sending events.
178
- metadata (dict): A dictionary containing metadata about the environment.
180
+ metadata (Dict): A dictionary containing metadata about the environment.
179
181
  enabled (bool): A flag to enable or disable Events based on certain conditions.
180
182
  """
181
183
 
182
184
  url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
183
185
 
184
186
  def __init__(self):
185
- """Initializes the Events object with default values for events, rate_limit, and metadata."""
187
+ """Initialize the Events object with default values for events, rate_limit, and metadata."""
186
188
  self.events = [] # events list
187
189
  self.rate_limit = 30.0 # rate limit (seconds)
188
190
  self.t = 0.0 # rate limit timer (seconds)
@@ -205,7 +207,7 @@ class Events:
205
207
 
206
208
  def __call__(self, cfg):
207
209
  """
208
- Attempts to add a new event to the events list and send events if the rate limit is reached.
210
+ Attempt to add a new event to the events list and send events if the rate limit is reached.
209
211
 
210
212
  Args:
211
213
  cfg (IterableSimpleNamespace): The configuration object containing mode and task information.