ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156):
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/utils/callbacks/neptune.py CHANGED
@@ -1,10 +1,12 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+
3
4
  from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
4
5
 
5
6
  try:
6
7
  assert not TESTS_RUNNING # do not log pytest
7
8
  assert SETTINGS["neptune"] is True # verify integration is enabled
9
+
8
10
  import neptune
9
11
  from neptune.types import File
10
12
 
@@ -16,27 +18,27 @@ except (ImportError, AssertionError):
16
18
  neptune = None
17
19
 
18
20
 
19
- def _log_scalars(scalars, step=0):
21
+ def _log_scalars(scalars: dict, step: int = 0) -> None:
20
22
  """Log scalars to the NeptuneAI experiment logger."""
21
23
  if run:
22
24
  for k, v in scalars.items():
23
25
  run[k].append(value=v, step=step)
24
26
 
25
27
 
26
- def _log_images(imgs_dict, group=""):
27
- """Log scalars to the NeptuneAI experiment logger."""
28
+ def _log_images(imgs_dict: dict, group: str = "") -> None:
29
+ """Log images to the NeptuneAI experiment logger."""
28
30
  if run:
29
31
  for k, v in imgs_dict.items():
30
32
  run[f"{group}/{k}"].upload(File(v))
31
33
 
32
34
 
33
- def _log_plot(title, plot_path):
35
+ def _log_plot(title: str, plot_path: str) -> None:
34
36
  """
35
37
  Log plots to the NeptuneAI experiment logger.
36
38
 
37
39
  Args:
38
40
  title (str): Title of the plot.
39
- plot_path (PosixPath | str): Path to the saved image file.
41
+ plot_path (str): Path to the saved image file.
40
42
  """
41
43
  import matplotlib.image as mpimg
42
44
  import matplotlib.pyplot as plt
@@ -48,7 +50,7 @@ def _log_plot(title, plot_path):
48
50
  run[f"Plots/{title}"].upload(fig)
49
51
 
50
52
 
51
- def on_pretrain_routine_start(trainer):
53
+ def on_pretrain_routine_start(trainer) -> None:
52
54
  """Callback function called before the training routine starts."""
53
55
  try:
54
56
  global run
@@ -62,7 +64,7 @@ def on_pretrain_routine_start(trainer):
62
64
  LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
63
65
 
64
66
 
65
- def on_train_epoch_end(trainer):
67
+ def on_train_epoch_end(trainer) -> None:
66
68
  """Callback function called at end of each training epoch."""
67
69
  _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
68
70
  _log_scalars(trainer.lr, trainer.epoch + 1)
@@ -70,7 +72,7 @@ def on_train_epoch_end(trainer):
70
72
  _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
71
73
 
72
74
 
73
- def on_fit_epoch_end(trainer):
75
+ def on_fit_epoch_end(trainer) -> None:
74
76
  """Callback function called at end of each fit (train+val) epoch."""
75
77
  if run and trainer.epoch == 0:
76
78
  from ultralytics.utils.torch_utils import model_info_for_loggers
@@ -79,14 +81,14 @@ def on_fit_epoch_end(trainer):
79
81
  _log_scalars(trainer.metrics, trainer.epoch + 1)
80
82
 
81
83
 
82
- def on_val_end(validator):
84
+ def on_val_end(validator) -> None:
83
85
  """Callback function called at end of each validation."""
84
86
  if run:
85
87
  # Log val_labels and val_pred
86
88
  _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
87
89
 
88
90
 
89
- def on_train_end(trainer):
91
+ def on_train_end(trainer) -> None:
90
92
  """Callback function called at end of training."""
91
93
  if run:
92
94
  # Log final results, CM matrix + PR plots
ultralytics/utils/callbacks/raytune.py CHANGED
@@ -14,7 +14,7 @@ except (ImportError, AssertionError):
14
14
 
15
15
  def on_fit_epoch_end(trainer):
16
16
  """Sends training metrics to Ray Tune at end of each epoch."""
17
- if ray.train._internal.session.get_session(): # replacement for deprecated ray.tune.is_session_enabled()
17
+ if ray.train._internal.session.get_session(): # check if Ray Tune session is active
18
18
  metrics = trainer.metrics
19
19
  session.report({**metrics, **{"epoch": trainer.epoch + 1}})
20
20
 
ultralytics/utils/callbacks/tensorboard.py CHANGED
@@ -23,14 +23,14 @@ except (ImportError, AssertionError, TypeError, AttributeError):
23
23
  SummaryWriter = None
24
24
 
25
25
 
26
- def _log_scalars(scalars, step=0):
26
+ def _log_scalars(scalars: dict, step: int = 0) -> None:
27
27
  """Logs scalar values to TensorBoard."""
28
28
  if WRITER:
29
29
  for k, v in scalars.items():
30
30
  WRITER.add_scalar(k, v, step)
31
31
 
32
32
 
33
- def _log_tensorboard_graph(trainer):
33
+ def _log_tensorboard_graph(trainer) -> None:
34
34
  """Log model graph to TensorBoard."""
35
35
  # Input image
36
36
  imgsz = trainer.args.imgsz
@@ -66,7 +66,7 @@ def _log_tensorboard_graph(trainer):
66
66
  LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
67
67
 
68
68
 
69
- def on_pretrain_routine_start(trainer):
69
+ def on_pretrain_routine_start(trainer) -> None:
70
70
  """Initialize TensorBoard logging with SummaryWriter."""
71
71
  if SummaryWriter:
72
72
  try:
@@ -77,19 +77,19 @@ def on_pretrain_routine_start(trainer):
77
77
  LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
78
78
 
79
79
 
80
- def on_train_start(trainer):
80
+ def on_train_start(trainer) -> None:
81
81
  """Log TensorBoard graph."""
82
82
  if WRITER:
83
83
  _log_tensorboard_graph(trainer)
84
84
 
85
85
 
86
- def on_train_epoch_end(trainer):
86
+ def on_train_epoch_end(trainer) -> None:
87
87
  """Logs scalar statistics at the end of a training epoch."""
88
88
  _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
89
89
  _log_scalars(trainer.lr, trainer.epoch + 1)
90
90
 
91
91
 
92
- def on_fit_epoch_end(trainer):
92
+ def on_fit_epoch_end(trainer) -> None:
93
93
  """Logs epoch metrics at end of training epoch."""
94
94
  _log_scalars(trainer.metrics, trainer.epoch + 1)
95
95
 
ultralytics/utils/callbacks/wb.py CHANGED
@@ -27,9 +27,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
27
27
  x (List): Values for the x-axis; expected to have length N.
28
28
  y (List): Corresponding values for the y-axis; also expected to have length N.
29
29
  classes (List): Labels identifying the class of each point; length N.
30
- title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
31
- x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
32
- y_title (str, optional): Label for the y-axis; defaults to 'Precision'.
30
+ title (str): Title for the plot; defaults to 'Precision Recall Curve'.
31
+ x_title (str): Label for the x-axis; defaults to 'Recall'.
32
+ y_title (str): Label for the y-axis; defaults to 'Precision'.
33
33
 
34
34
  Returns:
35
35
  (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
@@ -63,16 +63,16 @@ def _plot_curve(
63
63
 
64
64
  Args:
65
65
  x (np.ndarray): Data points for the x-axis with length N.
66
- y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
67
- names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
68
- id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
69
- title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
70
- x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
71
- y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
72
- num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
73
- only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to True.
74
-
75
- Note:
66
+ y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
67
+ names (List): Names of the classes corresponding to the y-axis data; length C.
68
+ id (str): Unique identifier for the logged data in wandb.
69
+ title (str): Title for the visualization plot.
70
+ x_title (str): Label for the x-axis.
71
+ y_title (str): Label for the y-axis.
72
+ num_x (int): Number of interpolated data points for visualization.
73
+ only_mean (bool): Flag to indicate if only the mean curve should be plotted.
74
+
75
+ Notes:
76
76
  The function leverages the '_custom_table' function to generate the actual visualization.
77
77
  """
78
78
  import numpy as np
@@ -108,7 +108,7 @@ def _log_plots(plots, step):
108
108
 
109
109
 
110
110
  def on_pretrain_routine_start(trainer):
111
- """Initiate and start project if module is present."""
111
+ """Initiate and start wandb project if module is present."""
112
112
  if not wb.run:
113
113
  wb.init(
114
114
  project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
@@ -118,7 +118,7 @@ def on_pretrain_routine_start(trainer):
118
118
 
119
119
 
120
120
  def on_fit_epoch_end(trainer):
121
- """Logs training metrics and model information at the end of an epoch."""
121
+ """Log training metrics and model information at the end of an epoch."""
122
122
  wb.run.log(trainer.metrics, step=trainer.epoch + 1)
123
123
  _log_plots(trainer.plots, step=trainer.epoch + 1)
124
124
  _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
@@ -135,7 +135,7 @@ def on_train_epoch_end(trainer):
135
135
 
136
136
 
137
137
  def on_train_end(trainer):
138
- """Save the best model as an artifact at end of training."""
138
+ """Save the best model as an artifact and log final plots at the end of training."""
139
139
  _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
140
140
  _log_plots(trainer.plots, step=trainer.epoch + 1)
141
141
  art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
ultralytics/utils/checks.py CHANGED
@@ -55,10 +55,10 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
55
55
 
56
56
  Args:
57
57
  file_path (Path): Path to the requirements.txt file.
58
- package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'.
58
+ package (str, optional): Python package to use instead of requirements.txt file.
59
59
 
60
60
  Returns:
61
- (List[Dict[str, str]]): List of parsed requirements as dictionaries with `name` and `specifier` keys.
61
+ (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and `specifier` attributes.
62
62
 
63
63
  Examples:
64
64
  >>> from ultralytics.utils.checks import parse_requirements
@@ -82,14 +82,13 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
82
82
 
83
83
  def parse_version(version="0.0.0") -> tuple:
84
84
  """
85
- Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
86
- function replaces deprecated 'pkg_resources.parse_version(v)'.
85
+ Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
87
86
 
88
87
  Args:
89
88
  version (str): Version string, i.e. '2.0.1+cpu'
90
89
 
91
90
  Returns:
92
- (tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1)
91
+ (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)
93
92
  """
94
93
  try:
95
94
  return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
@@ -121,7 +120,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
121
120
  stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
122
121
 
123
122
  Args:
124
- imgsz (int | cList[int]): Image size.
123
+ imgsz (int | List[int]): Image size.
125
124
  stride (int): Stride value.
126
125
  min_dim (int): Minimum number of dimensions.
127
126
  max_dim (int): Maximum number of dimensions.
@@ -183,10 +182,10 @@ def check_version(
183
182
  Args:
184
183
  current (str): Current version or package name to get version from.
185
184
  required (str): Required version or range (in pip-style format).
186
- name (str, optional): Name to be used in warning message.
187
- hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
188
- verbose (bool, optional): If True, print warning message if requirement is not met.
189
- msg (str, optional): Extra message to display if verbose.
185
+ name (str): Name to be used in warning message.
186
+ hard (bool): If True, raise an AssertionError if the requirement is not met.
187
+ verbose (bool): If True, print warning message if requirement is not met.
188
+ msg (str): Extra message to display if verbose.
190
189
 
191
190
  Returns:
192
191
  (bool): True if requirement is met, False otherwise.
@@ -308,7 +307,7 @@ def check_font(font="Arial.ttf"):
308
307
  font (str): Path or name of font.
309
308
 
310
309
  Returns:
311
- file (Path): Resolved font file path.
310
+ (Path): Resolved font file path.
312
311
  """
313
312
  from matplotlib import font_manager
314
313
 
@@ -336,8 +335,8 @@ def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = Fals
336
335
 
337
336
  Args:
338
337
  minimum (str): Required minimum version of python.
339
- hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
340
- verbose (bool, optional): If True, print warning message if requirement is not met.
338
+ hard (bool): If True, raise an AssertionError if the requirement is not met.
339
+ verbose (bool): If True, print warning message if requirement is not met.
341
340
 
342
341
  Returns:
343
342
  (bool): Whether the installed Python version meets the minimum constraints.
@@ -420,11 +419,7 @@ def check_torchvision():
420
419
  Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
421
420
 
422
421
  This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
423
- to the provided compatibility table based on:
424
- https://github.com/pytorch/vision#installation.
425
-
426
- The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
427
- Torchvision versions.
422
+ to the compatibility table based on: https://github.com/pytorch/vision#installation.
428
423
  """
429
424
  compatibility_table = {
430
425
  "2.6": ["0.21"],
@@ -453,7 +448,14 @@ def check_torchvision():
453
448
 
454
449
 
455
450
  def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
456
- """Check file(s) for acceptable suffix."""
451
+ """
452
+ Check file(s) for acceptable suffix.
453
+
454
+ Args:
455
+ file (str | List[str]): File or list of files to check.
456
+ suffix (str | Tuple[str]): Acceptable suffix or tuple of suffixes.
457
+ msg (str): Additional message to display in case of error.
458
+ """
457
459
  if file and suffix:
458
460
  if isinstance(suffix, str):
459
461
  suffix = (suffix,)
@@ -464,7 +466,16 @@ def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
464
466
 
465
467
 
466
468
  def check_yolov5u_filename(file: str, verbose: bool = True):
467
- """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
469
+ """
470
+ Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.
471
+
472
+ Args:
473
+ file (str): Filename to check and potentially update.
474
+ verbose (bool): Whether to print information about the replacement.
475
+
476
+ Returns:
477
+ (str): Updated filename.
478
+ """
468
479
  if "yolov3" in file or "yolov5" in file:
469
480
  if "u.yaml" in file:
470
481
  file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
@@ -483,7 +494,15 @@ def check_yolov5u_filename(file: str, verbose: bool = True):
483
494
 
484
495
 
485
496
  def check_model_file_from_stem(model="yolo11n"):
486
- """Return a model filename from a valid model stem."""
497
+ """
498
+ Return a model filename from a valid model stem.
499
+
500
+ Args:
501
+ model (str): Model stem to check.
502
+
503
+ Returns:
504
+ (str | Path): Model filename with appropriate suffix.
505
+ """
487
506
  if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
488
507
  return Path(model).with_suffix(".pt") # add suffix, i.e. yolo11n -> yolo11n.pt
489
508
  else:
@@ -491,7 +510,19 @@ def check_model_file_from_stem(model="yolo11n"):
491
510
 
492
511
 
493
512
  def check_file(file, suffix="", download=True, download_dir=".", hard=True):
494
- """Search/download file (if necessary) and return path."""
513
+ """
514
+ Search/download file (if necessary) and return path.
515
+
516
+ Args:
517
+ file (str): File name or path.
518
+ suffix (str): File suffix to check.
519
+ download (bool): Whether to download the file if it doesn't exist locally.
520
+ download_dir (str): Directory to download the file to.
521
+ hard (bool): Whether to raise an error if the file is not found.
522
+
523
+ Returns:
524
+ (str): Path to the file.
525
+ """
495
526
  check_suffix(file, suffix) # optional
496
527
  file = str(file).strip() # convert to string and strip spaces
497
528
  file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
@@ -519,7 +550,17 @@ def check_file(file, suffix="", download=True, download_dir=".", hard=True):
519
550
 
520
551
 
521
552
  def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
522
- """Search/download YAML file (if necessary) and return path, checking suffix."""
553
+ """
554
+ Search/download YAML file (if necessary) and return path, checking suffix.
555
+
556
+ Args:
557
+ file (str): File name or path.
558
+ suffix (tuple): Acceptable file suffixes.
559
+ hard (bool): Whether to raise an error if the file is not found.
560
+
561
+ Returns:
562
+ (str): Path to the YAML file.
563
+ """
523
564
  return check_file(file, suffix, hard=hard)
524
565
 
525
566
 
@@ -541,7 +582,15 @@ def check_is_path_safe(basedir, path):
541
582
 
542
583
 
543
584
  def check_imshow(warn=False):
544
- """Check if environment supports image displays."""
585
+ """
586
+ Check if environment supports image displays.
587
+
588
+ Args:
589
+ warn (bool): Whether to warn if environment doesn't support image displays.
590
+
591
+ Returns:
592
+ (bool): True if environment supports image displays, False otherwise.
593
+ """
545
594
  try:
546
595
  if LINUX:
547
596
  assert not IS_COLAB and not IS_KAGGLE
@@ -558,7 +607,13 @@ def check_imshow(warn=False):
558
607
 
559
608
 
560
609
  def check_yolo(verbose=True, device=""):
561
- """Return a human-readable YOLO software and hardware summary."""
610
+ """
611
+ Return a human-readable YOLO software and hardware summary.
612
+
613
+ Args:
614
+ verbose (bool): Whether to print verbose information.
615
+ device (str): Device to use for YOLO.
616
+ """
562
617
  import psutil
563
618
 
564
619
  from ultralytics.utils.torch_utils import select_device
@@ -586,7 +641,12 @@ def check_yolo(verbose=True, device=""):
586
641
 
587
642
 
588
643
  def collect_system_info():
589
- """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""
644
+ """
645
+ Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.
646
+
647
+ Returns:
648
+ (dict): Dictionary containing system information.
649
+ """
590
650
  import psutil
591
651
 
592
652
  from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
@@ -643,21 +703,22 @@ def collect_system_info():
643
703
 
644
704
  def check_amp(model):
645
705
  """
646
- Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model. If the checks fail, it means
647
- there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled
648
- during training.
706
+ Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model.
707
+
708
+ If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
709
+ results, so AMP will be disabled during training.
649
710
 
650
711
  Args:
651
712
  model (nn.Module): A YOLO11 model instance.
652
713
 
714
+ Returns:
715
+ (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.
716
+
653
717
  Examples:
654
718
  >>> from ultralytics import YOLO
655
719
  >>> from ultralytics.utils.checks import check_amp
656
720
  >>> model = YOLO("yolo11n.pt").model.cuda()
657
721
  >>> check_amp(model)
658
-
659
- Returns:
660
- (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.
661
722
  """
662
723
  from ultralytics.utils.torch_utils import autocast
663
724
 
@@ -716,7 +777,15 @@ def check_amp(model):
716
777
 
717
778
 
718
779
  def git_describe(path=ROOT): # path must be a directory
719
- """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
780
+ """
781
+ Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.
782
+
783
+ Args:
784
+ path (Path): Path to git repository.
785
+
786
+ Returns:
787
+ (str): Human-readable git description.
788
+ """
720
789
  try:
721
790
  return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
722
791
  except Exception:
@@ -724,7 +793,14 @@ def git_describe(path=ROOT): # path must be a directory
724
793
 
725
794
 
726
795
  def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
727
- """Print function arguments (optional args dict)."""
796
+ """
797
+ Print function arguments (optional args dict).
798
+
799
+ Args:
800
+ args (dict, optional): Arguments to print.
801
+ show_file (bool): Whether to show the file name.
802
+ show_func (bool): Whether to show the function name.
803
+ """
728
804
 
729
805
  def strip_auth(v):
730
806
  """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
@@ -776,7 +852,12 @@ def cuda_is_available() -> bool:
776
852
 
777
853
 
778
854
  def is_rockchip():
779
- """Check if the current environment is running on a Rockchip SoC."""
855
+ """
856
+ Check if the current environment is running on a Rockchip SoC.
857
+
858
+ Returns:
859
+ (bool): True if running on a Rockchip SoC, False otherwise.
860
+ """
780
861
  if LINUX and ARM64:
781
862
  try:
782
863
  with open("/proc/device-tree/compatible") as f:
ultralytics/utils/dist.py CHANGED
@@ -12,10 +12,13 @@ from .torch_utils import TORCH_1_9
12
12
 
13
13
  def find_free_network_port() -> int:
14
14
  """
15
- Finds a free port on localhost.
15
+ Find a free port on localhost.
16
16
 
17
17
  It is useful in single-node training when we don't want to connect to a real main node but have to set the
18
18
  `MASTER_PORT` environment variable.
19
+
20
+ Returns:
21
+ (int): The available network port number.
19
22
  """
20
23
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
21
24
  s.bind(("127.0.0.1", 0))
@@ -54,7 +57,17 @@ if __name__ == "__main__":
54
57
 
55
58
 
56
59
  def generate_ddp_command(world_size, trainer):
57
- """Generates and returns command for distributed training."""
60
+ """
61
+ Generate command for distributed training.
62
+
63
+ Args:
64
+ world_size (int): Number of processes to spawn for distributed training.
65
+ trainer (object): The trainer object containing configuration for distributed training.
66
+
67
+ Returns:
68
+ cmd (List[str]): The command to execute for distributed training.
69
+ file (str): Path to the temporary file created for DDP training.
70
+ """
58
71
  import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
59
72
 
60
73
  if not trainer.resume:
ultralytics/utils/downloads.py CHANGED
@@ -65,7 +65,7 @@ def is_url(url, check=False):
65
65
 
66
66
  def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
67
67
  """
68
- Deletes all ".DS_store" files under a specified directory.
68
+ Delete all ".DS_store" files in a specified directory.
69
69
 
70
70
  Args:
71
71
  path (str, optional): The directory path where the ".DS_store" files should be deleted.
@@ -75,7 +75,7 @@ def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
75
75
  >>> from ultralytics.utils.downloads import delete_dsstore
76
76
  >>> delete_dsstore("path/to/dir")
77
77
 
78
- Note:
78
+ Notes:
79
79
  ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
80
80
  are hidden system files and can cause issues when transferring files between different operating systems.
81
81
  """
@@ -132,7 +132,7 @@ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=Fals
132
132
 
133
133
  Args:
134
134
  file (str | Path): The path to the zipfile to be extracted.
135
- path (str, optional): The path to extract the zipfile to. Defaults to None.
135
+ path (str | Path, optional): The path to extract the zipfile to. Defaults to None.
136
136
  exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
137
137
  exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
138
138
  progress (bool, optional): Whether to display a progress bar. Defaults to True.
@@ -280,7 +280,7 @@ def safe_download(
280
280
  url (str): The URL of the file to be downloaded.
281
281
  file (str, optional): The filename of the downloaded file.
282
282
  If not provided, the file will be saved with the same name as the URL.
283
- dir (str, optional): The directory to save the downloaded file.
283
+ dir (str | Path, optional): The directory to save the downloaded file.
284
284
  If not provided, the file will be saved in the current working directory.
285
285
  unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
286
286
  delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
@@ -291,6 +291,9 @@ def safe_download(
291
291
  exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
292
292
  progress (bool, optional): Whether to display a progress bar during the download. Default: True.
293
293
 
294
+ Returns:
295
+ (Path | str): The path to the downloaded file or extracted directory.
296
+
294
297
  Examples:
295
298
  >>> from ultralytics.utils.downloads import safe_download
296
299
  >>> link = "https://ultralytics.com/assets/bus.jpg"
@@ -359,6 +362,7 @@ def safe_download(
359
362
  if delete:
360
363
  f.unlink() # remove zip
361
364
  return unzip_dir
365
+ return f
362
366
 
363
367
 
364
368
  def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
@@ -372,7 +376,8 @@ def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
372
376
  retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False.
373
377
 
374
378
  Returns:
375
- (tuple): A tuple containing the release tag and a list of asset names.
379
+ (str): The release tag.
380
+ (List[str]): A list of asset names.
376
381
 
377
382
  Examples:
378
383
  >>> tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
@@ -392,14 +397,13 @@ def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
392
397
 
393
398
  def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs):
394
399
  """
395
- Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
396
- locally first, then tries to download it from the specified GitHub repository release.
400
+ Attempt to download a file from GitHub release assets if it is not found locally.
397
401
 
398
402
  Args:
399
403
  file (str | Path): The filename or file path to be downloaded.
400
404
  repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
401
405
  release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'.
402
- **kwargs (any): Additional keyword arguments for the download process.
406
+ **kwargs (Any): Additional keyword arguments for the download process.
403
407
 
404
408
  Returns:
405
409
  (str): The path to the downloaded file.
@@ -448,7 +452,7 @@ def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=
448
452
  specified.
449
453
 
450
454
  Args:
451
- url (str | list): The URL or list of URLs of the files to be downloaded.
455
+ url (str | List[str]): The URL or list of URLs of the files to be downloaded.
452
456
  dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory.
453
457
  unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
454
458
  delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.