ultralytics 8.3.88-py3-none-any.whl → 8.3.90-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +125 -39
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +34 -33
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +33 -47
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +69 -90
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +31 -38
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +21 -26
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +23 -17
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +29 -24
  53. ultralytics/models/nas/predict.py +14 -11
  54. ultralytics/models/nas/val.py +11 -13
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +21 -21
  57. ultralytics/models/rtdetr/train.py +25 -24
  58. ultralytics/models/rtdetr/val.py +47 -14
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +30 -12
  73. ultralytics/models/yolo/classify/train.py +83 -19
  74. ultralytics/models/yolo/classify/val.py +45 -23
  75. ultralytics/models/yolo/detect/predict.py +29 -19
  76. ultralytics/models/yolo/detect/train.py +90 -23
  77. ultralytics/models/yolo/detect/val.py +150 -29
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +18 -13
  80. ultralytics/models/yolo/obb/train.py +12 -8
  81. ultralytics/models/yolo/obb/val.py +35 -22
  82. ultralytics/models/yolo/pose/predict.py +28 -15
  83. ultralytics/models/yolo/pose/train.py +21 -8
  84. ultralytics/models/yolo/pose/val.py +51 -31
  85. ultralytics/models/yolo/segment/predict.py +27 -16
  86. ultralytics/models/yolo/segment/train.py +11 -8
  87. ultralytics/models/yolo/segment/val.py +110 -29
  88. ultralytics/models/yolo/world/train.py +43 -16
  89. ultralytics/models/yolo/world/train_world.py +61 -36
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +12 -12
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +226 -79
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +37 -35
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +139 -68
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +37 -56
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +117 -52
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +65 -61
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +72 -59
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +202 -64
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +13 -25
  149. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.88.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/utils/callbacks/neptune.py CHANGED
@@ -1,10 +1,12 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+
3
4
  from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
4
5
 
5
6
  try:
6
7
  assert not TESTS_RUNNING # do not log pytest
7
8
  assert SETTINGS["neptune"] is True # verify integration is enabled
9
+
8
10
  import neptune
9
11
  from neptune.types import File
10
12
 
@@ -16,27 +18,27 @@ except (ImportError, AssertionError):
16
18
  neptune = None
17
19
 
18
20
 
19
- def _log_scalars(scalars, step=0):
21
+ def _log_scalars(scalars: dict, step: int = 0) -> None:
20
22
  """Log scalars to the NeptuneAI experiment logger."""
21
23
  if run:
22
24
  for k, v in scalars.items():
23
25
  run[k].append(value=v, step=step)
24
26
 
25
27
 
26
- def _log_images(imgs_dict, group=""):
27
- """Log scalars to the NeptuneAI experiment logger."""
28
+ def _log_images(imgs_dict: dict, group: str = "") -> None:
29
+ """Log images to the NeptuneAI experiment logger."""
28
30
  if run:
29
31
  for k, v in imgs_dict.items():
30
32
  run[f"{group}/{k}"].upload(File(v))
31
33
 
32
34
 
33
- def _log_plot(title, plot_path):
35
+ def _log_plot(title: str, plot_path: str) -> None:
34
36
  """
35
37
  Log plots to the NeptuneAI experiment logger.
36
38
 
37
39
  Args:
38
40
  title (str): Title of the plot.
39
- plot_path (PosixPath | str): Path to the saved image file.
41
+ plot_path (str): Path to the saved image file.
40
42
  """
41
43
  import matplotlib.image as mpimg
42
44
  import matplotlib.pyplot as plt
@@ -48,7 +50,7 @@ def _log_plot(title, plot_path):
48
50
  run[f"Plots/{title}"].upload(fig)
49
51
 
50
52
 
51
- def on_pretrain_routine_start(trainer):
53
+ def on_pretrain_routine_start(trainer) -> None:
52
54
  """Callback function called before the training routine starts."""
53
55
  try:
54
56
  global run
@@ -62,7 +64,7 @@ def on_pretrain_routine_start(trainer):
62
64
  LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
63
65
 
64
66
 
65
- def on_train_epoch_end(trainer):
67
+ def on_train_epoch_end(trainer) -> None:
66
68
  """Callback function called at end of each training epoch."""
67
69
  _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
68
70
  _log_scalars(trainer.lr, trainer.epoch + 1)
@@ -70,7 +72,7 @@ def on_train_epoch_end(trainer):
70
72
  _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
71
73
 
72
74
 
73
- def on_fit_epoch_end(trainer):
75
+ def on_fit_epoch_end(trainer) -> None:
74
76
  """Callback function called at end of each fit (train+val) epoch."""
75
77
  if run and trainer.epoch == 0:
76
78
  from ultralytics.utils.torch_utils import model_info_for_loggers
@@ -79,14 +81,14 @@ def on_fit_epoch_end(trainer):
79
81
  _log_scalars(trainer.metrics, trainer.epoch + 1)
80
82
 
81
83
 
82
- def on_val_end(validator):
84
+ def on_val_end(validator) -> None:
83
85
  """Callback function called at end of each validation."""
84
86
  if run:
85
87
  # Log val_labels and val_pred
86
88
  _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
87
89
 
88
90
 
89
- def on_train_end(trainer):
91
+ def on_train_end(trainer) -> None:
90
92
  """Callback function called at end of training."""
91
93
  if run:
92
94
  # Log final results, CM matrix + PR plots
ultralytics/utils/callbacks/raytune.py CHANGED
@@ -14,7 +14,7 @@ except (ImportError, AssertionError):
14
14
 
15
15
  def on_fit_epoch_end(trainer):
16
16
  """Sends training metrics to Ray Tune at end of each epoch."""
17
- if ray.train._internal.session.get_session(): # replacement for deprecated ray.tune.is_session_enabled()
17
+ if ray.train._internal.session.get_session(): # check if Ray Tune session is active
18
18
  metrics = trainer.metrics
19
19
  session.report({**metrics, **{"epoch": trainer.epoch + 1}})
20
20
 
ultralytics/utils/callbacks/tensorboard.py CHANGED
@@ -23,14 +23,14 @@ except (ImportError, AssertionError, TypeError, AttributeError):
23
23
  SummaryWriter = None
24
24
 
25
25
 
26
- def _log_scalars(scalars, step=0):
26
+ def _log_scalars(scalars: dict, step: int = 0) -> None:
27
27
  """Logs scalar values to TensorBoard."""
28
28
  if WRITER:
29
29
  for k, v in scalars.items():
30
30
  WRITER.add_scalar(k, v, step)
31
31
 
32
32
 
33
- def _log_tensorboard_graph(trainer):
33
+ def _log_tensorboard_graph(trainer) -> None:
34
34
  """Log model graph to TensorBoard."""
35
35
  # Input image
36
36
  imgsz = trainer.args.imgsz
@@ -66,7 +66,7 @@ def _log_tensorboard_graph(trainer):
66
66
  LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
67
67
 
68
68
 
69
- def on_pretrain_routine_start(trainer):
69
+ def on_pretrain_routine_start(trainer) -> None:
70
70
  """Initialize TensorBoard logging with SummaryWriter."""
71
71
  if SummaryWriter:
72
72
  try:
@@ -77,19 +77,19 @@ def on_pretrain_routine_start(trainer):
77
77
  LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
78
78
 
79
79
 
80
- def on_train_start(trainer):
80
+ def on_train_start(trainer) -> None:
81
81
  """Log TensorBoard graph."""
82
82
  if WRITER:
83
83
  _log_tensorboard_graph(trainer)
84
84
 
85
85
 
86
- def on_train_epoch_end(trainer):
86
+ def on_train_epoch_end(trainer) -> None:
87
87
  """Logs scalar statistics at the end of a training epoch."""
88
88
  _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
89
89
  _log_scalars(trainer.lr, trainer.epoch + 1)
90
90
 
91
91
 
92
- def on_fit_epoch_end(trainer):
92
+ def on_fit_epoch_end(trainer) -> None:
93
93
  """Logs epoch metrics at end of training epoch."""
94
94
  _log_scalars(trainer.metrics, trainer.epoch + 1)
95
95
 
ultralytics/utils/callbacks/wb.py CHANGED
@@ -27,9 +27,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
27
27
  x (List): Values for the x-axis; expected to have length N.
28
28
  y (List): Corresponding values for the y-axis; also expected to have length N.
29
29
  classes (List): Labels identifying the class of each point; length N.
30
- title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
31
- x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
32
- y_title (str, optional): Label for the y-axis; defaults to 'Precision'.
30
+ title (str): Title for the plot; defaults to 'Precision Recall Curve'.
31
+ x_title (str): Label for the x-axis; defaults to 'Recall'.
32
+ y_title (str): Label for the y-axis; defaults to 'Precision'.
33
33
 
34
34
  Returns:
35
35
  (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
@@ -63,16 +63,16 @@ def _plot_curve(
63
63
 
64
64
  Args:
65
65
  x (np.ndarray): Data points for the x-axis with length N.
66
- y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
67
- names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
68
- id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
69
- title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
70
- x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
71
- y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
72
- num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
73
- only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to True.
74
-
75
- Note:
66
+ y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
67
+ names (List): Names of the classes corresponding to the y-axis data; length C.
68
+ id (str): Unique identifier for the logged data in wandb.
69
+ title (str): Title for the visualization plot.
70
+ x_title (str): Label for the x-axis.
71
+ y_title (str): Label for the y-axis.
72
+ num_x (int): Number of interpolated data points for visualization.
73
+ only_mean (bool): Flag to indicate if only the mean curve should be plotted.
74
+
75
+ Notes:
76
76
  The function leverages the '_custom_table' function to generate the actual visualization.
77
77
  """
78
78
  import numpy as np
@@ -108,7 +108,7 @@ def _log_plots(plots, step):
108
108
 
109
109
 
110
110
  def on_pretrain_routine_start(trainer):
111
- """Initiate and start project if module is present."""
111
+ """Initiate and start wandb project if module is present."""
112
112
  if not wb.run:
113
113
  wb.init(
114
114
  project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
@@ -118,7 +118,7 @@ def on_pretrain_routine_start(trainer):
118
118
 
119
119
 
120
120
  def on_fit_epoch_end(trainer):
121
- """Logs training metrics and model information at the end of an epoch."""
121
+ """Log training metrics and model information at the end of an epoch."""
122
122
  wb.run.log(trainer.metrics, step=trainer.epoch + 1)
123
123
  _log_plots(trainer.plots, step=trainer.epoch + 1)
124
124
  _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
@@ -135,7 +135,7 @@ def on_train_epoch_end(trainer):
135
135
 
136
136
 
137
137
  def on_train_end(trainer):
138
- """Save the best model as an artifact at end of training."""
138
+ """Save the best model as an artifact and log final plots at the end of training."""
139
139
  _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
140
140
  _log_plots(trainer.plots, step=trainer.epoch + 1)
141
141
  art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
ultralytics/utils/checks.py CHANGED
@@ -55,17 +55,14 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
55
55
 
56
56
  Args:
57
57
  file_path (Path): Path to the requirements.txt file.
58
- package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'.
58
+ package (str, optional): Python package to use instead of requirements.txt file.
59
59
 
60
60
  Returns:
61
- (List[Dict[str, str]]): List of parsed requirements as dictionaries with `name` and `specifier` keys.
61
+ (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and `specifier` attributes.
62
62
 
63
- Example:
64
- ```python
65
- from ultralytics.utils.checks import parse_requirements
66
-
67
- parse_requirements(package="ultralytics")
68
- ```
63
+ Examples:
64
+ >>> from ultralytics.utils.checks import parse_requirements
65
+ >>> parse_requirements(package="ultralytics")
69
66
  """
70
67
  if package:
71
68
  requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
@@ -85,14 +82,13 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
85
82
 
86
83
  def parse_version(version="0.0.0") -> tuple:
87
84
  """
88
- Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
89
- function replaces deprecated 'pkg_resources.parse_version(v)'.
85
+ Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
90
86
 
91
87
  Args:
92
88
  version (str): Version string, i.e. '2.0.1+cpu'
93
89
 
94
90
  Returns:
95
- (tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1)
91
+ (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)
96
92
  """
97
93
  try:
98
94
  return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
@@ -124,7 +120,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
124
120
  stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
125
121
 
126
122
  Args:
127
- imgsz (int | cList[int]): Image size.
123
+ imgsz (int | List[int]): Image size.
128
124
  stride (int): Stride value.
129
125
  min_dim (int): Minimum number of dimensions.
130
126
  max_dim (int): Maximum number of dimensions.
@@ -186,28 +182,26 @@ def check_version(
186
182
  Args:
187
183
  current (str): Current version or package name to get version from.
188
184
  required (str): Required version or range (in pip-style format).
189
- name (str, optional): Name to be used in warning message.
190
- hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
191
- verbose (bool, optional): If True, print warning message if requirement is not met.
192
- msg (str, optional): Extra message to display if verbose.
185
+ name (str): Name to be used in warning message.
186
+ hard (bool): If True, raise an AssertionError if the requirement is not met.
187
+ verbose (bool): If True, print warning message if requirement is not met.
188
+ msg (str): Extra message to display if verbose.
193
189
 
194
190
  Returns:
195
191
  (bool): True if requirement is met, False otherwise.
196
192
 
197
- Example:
198
- ```python
199
- # Check if current version is exactly 22.04
200
- check_version(current="22.04", required="==22.04")
193
+ Examples:
194
+ Check if current version is exactly 22.04
195
+ >>> check_version(current="22.04", required="==22.04")
201
196
 
202
- # Check if current version is greater than or equal to 22.04
203
- check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed
197
+ Check if current version is greater than or equal to 22.04
198
+ >>> check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed
204
199
 
205
- # Check if current version is less than or equal to 22.04
206
- check_version(current="22.04", required="<=22.04")
200
+ Check if current version is less than or equal to 22.04
201
+ >>> check_version(current="22.04", required="<=22.04")
207
202
 
208
- # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
209
- check_version(current="21.10", required=">20.04,<22.04")
210
- ```
203
+ Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
204
+ >>> check_version(current="21.10", required=">20.04,<22.04")
211
205
  """
212
206
  if not current: # if current is '' or None
213
207
  LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
@@ -313,7 +307,7 @@ def check_font(font="Arial.ttf"):
313
307
  font (str): Path or name of font.
314
308
 
315
309
  Returns:
316
- file (Path): Resolved font file path.
310
+ (Path): Resolved font file path.
317
311
  """
318
312
  from matplotlib import font_manager
319
313
 
@@ -341,8 +335,8 @@ def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = Fals
341
335
 
342
336
  Args:
343
337
  minimum (str): Required minimum version of python.
344
- hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
345
- verbose (bool, optional): If True, print warning message if requirement is not met.
338
+ hard (bool): If True, raise an AssertionError if the requirement is not met.
339
+ verbose (bool): If True, print warning message if requirement is not met.
346
340
 
347
341
  Returns:
348
342
  (bool): Whether the installed Python version meets the minimum constraints.
@@ -362,19 +356,17 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
362
356
  install (bool): If True, attempt to auto-update packages that don't meet requirements.
363
357
  cmds (str): Additional commands to pass to the pip install command when auto-updating.
364
358
 
365
- Example:
366
- ```python
367
- from ultralytics.utils.checks import check_requirements
359
+ Examples:
360
+ >>> from ultralytics.utils.checks import check_requirements
368
361
 
369
- # Check a requirements.txt file
370
- check_requirements("path/to/requirements.txt")
362
+ Check a requirements.txt file
363
+ >>> check_requirements("path/to/requirements.txt")
371
364
 
372
- # Check a single package
373
- check_requirements("ultralytics>=8.0.0")
365
+ Check a single package
366
+ >>> check_requirements("ultralytics>=8.0.0")
374
367
 
375
- # Check multiple packages
376
- check_requirements(["numpy", "ultralytics>=8.0.0"])
377
- ```
368
+ Check multiple packages
369
+ >>> check_requirements(["numpy", "ultralytics>=8.0.0"])
378
370
  """
379
371
  prefix = colorstr("red", "bold", "requirements:")
380
372
  if isinstance(requirements, Path): # requirements.txt file
@@ -427,11 +419,7 @@ def check_torchvision():
427
419
  Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
428
420
 
429
421
  This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
430
- to the provided compatibility table based on:
431
- https://github.com/pytorch/vision#installation.
432
-
433
- The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
434
- Torchvision versions.
422
+ to the compatibility table based on: https://github.com/pytorch/vision#installation.
435
423
  """
436
424
  compatibility_table = {
437
425
  "2.6": ["0.21"],
@@ -460,7 +448,14 @@ def check_torchvision():
460
448
 
461
449
 
462
450
  def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
463
- """Check file(s) for acceptable suffix."""
451
+ """
452
+ Check file(s) for acceptable suffix.
453
+
454
+ Args:
455
+ file (str | List[str]): File or list of files to check.
456
+ suffix (str | Tuple[str]): Acceptable suffix or tuple of suffixes.
457
+ msg (str): Additional message to display in case of error.
458
+ """
464
459
  if file and suffix:
465
460
  if isinstance(suffix, str):
466
461
  suffix = (suffix,)
@@ -471,7 +466,16 @@ def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
471
466
 
472
467
 
473
468
  def check_yolov5u_filename(file: str, verbose: bool = True):
474
- """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
469
+ """
470
+ Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.
471
+
472
+ Args:
473
+ file (str): Filename to check and potentially update.
474
+ verbose (bool): Whether to print information about the replacement.
475
+
476
+ Returns:
477
+ (str): Updated filename.
478
+ """
475
479
  if "yolov3" in file or "yolov5" in file:
476
480
  if "u.yaml" in file:
477
481
  file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
@@ -490,7 +494,15 @@ def check_yolov5u_filename(file: str, verbose: bool = True):
490
494
 
491
495
 
492
496
  def check_model_file_from_stem(model="yolo11n"):
493
- """Return a model filename from a valid model stem."""
497
+ """
498
+ Return a model filename from a valid model stem.
499
+
500
+ Args:
501
+ model (str): Model stem to check.
502
+
503
+ Returns:
504
+ (str | Path): Model filename with appropriate suffix.
505
+ """
494
506
  if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
495
507
  return Path(model).with_suffix(".pt") # add suffix, i.e. yolo11n -> yolo11n.pt
496
508
  else:
@@ -498,7 +510,19 @@ def check_model_file_from_stem(model="yolo11n"):
498
510
 
499
511
 
500
512
  def check_file(file, suffix="", download=True, download_dir=".", hard=True):
501
- """Search/download file (if necessary) and return path."""
513
+ """
514
+ Search/download file (if necessary) and return path.
515
+
516
+ Args:
517
+ file (str): File name or path.
518
+ suffix (str): File suffix to check.
519
+ download (bool): Whether to download the file if it doesn't exist locally.
520
+ download_dir (str): Directory to download the file to.
521
+ hard (bool): Whether to raise an error if the file is not found.
522
+
523
+ Returns:
524
+ (str): Path to the file.
525
+ """
502
526
  check_suffix(file, suffix) # optional
503
527
  file = str(file).strip() # convert to string and strip spaces
504
528
  file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
@@ -526,7 +550,17 @@ def check_file(file, suffix="", download=True, download_dir=".", hard=True):
526
550
 
527
551
 
528
552
  def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
529
- """Search/download YAML file (if necessary) and return path, checking suffix."""
553
+ """
554
+ Search/download YAML file (if necessary) and return path, checking suffix.
555
+
556
+ Args:
557
+ file (str): File name or path.
558
+ suffix (tuple): Acceptable file suffixes.
559
+ hard (bool): Whether to raise an error if the file is not found.
560
+
561
+ Returns:
562
+ (str): Path to the YAML file.
563
+ """
530
564
  return check_file(file, suffix, hard=hard)
531
565
 
532
566
 
@@ -548,7 +582,15 @@ def check_is_path_safe(basedir, path):
548
582
 
549
583
 
550
584
  def check_imshow(warn=False):
551
- """Check if environment supports image displays."""
585
+ """
586
+ Check if environment supports image displays.
587
+
588
+ Args:
589
+ warn (bool): Whether to warn if environment doesn't support image displays.
590
+
591
+ Returns:
592
+ (bool): True if environment supports image displays, False otherwise.
593
+ """
552
594
  try:
553
595
  if LINUX:
554
596
  assert not IS_COLAB and not IS_KAGGLE
@@ -565,7 +607,13 @@ def check_imshow(warn=False):
565
607
 
566
608
 
567
609
  def check_yolo(verbose=True, device=""):
568
- """Return a human-readable YOLO software and hardware summary."""
610
+ """
611
+ Return a human-readable YOLO software and hardware summary.
612
+
613
+ Args:
614
+ verbose (bool): Whether to print verbose information.
615
+ device (str): Device to use for YOLO.
616
+ """
569
617
  import psutil
570
618
 
571
619
  from ultralytics.utils.torch_utils import select_device
@@ -593,7 +641,12 @@ def check_yolo(verbose=True, device=""):
593
641
 
594
642
 
595
643
  def collect_system_info():
596
- """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""
644
+ """
645
+ Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.
646
+
647
+ Returns:
648
+ (dict): Dictionary containing system information.
649
+ """
597
650
  import psutil
598
651
 
599
652
  from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
@@ -650,24 +703,22 @@ def collect_system_info():
650
703
 
651
704
  def check_amp(model):
652
705
  """
653
- Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model. If the checks fail, it means
654
- there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled
655
- during training.
706
+ Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model.
707
+
708
+ If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
709
+ results, so AMP will be disabled during training.
656
710
 
657
711
  Args:
658
712
  model (nn.Module): A YOLO11 model instance.
659
713
 
660
- Example:
661
- ```python
662
- from ultralytics import YOLO
663
- from ultralytics.utils.checks import check_amp
664
-
665
- model = YOLO("yolo11n.pt").model.cuda()
666
- check_amp(model)
667
- ```
668
-
669
714
  Returns:
670
715
  (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.
716
+
717
+ Examples:
718
+ >>> from ultralytics import YOLO
719
+ >>> from ultralytics.utils.checks import check_amp
720
+ >>> model = YOLO("yolo11n.pt").model.cuda()
721
+ >>> check_amp(model)
671
722
  """
672
723
  from ultralytics.utils.torch_utils import autocast
673
724
 
@@ -726,7 +777,15 @@ def check_amp(model):
726
777
 
727
778
 
728
779
  def git_describe(path=ROOT): # path must be a directory
729
- """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
780
+ """
781
+ Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.
782
+
783
+ Args:
784
+ path (Path): Path to git repository.
785
+
786
+ Returns:
787
+ (str): Human-readable git description.
788
+ """
730
789
  try:
731
790
  return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
732
791
  except Exception:
@@ -734,7 +793,14 @@ def git_describe(path=ROOT): # path must be a directory
734
793
 
735
794
 
736
795
  def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
737
- """Print function arguments (optional args dict)."""
796
+ """
797
+ Print function arguments (optional args dict).
798
+
799
+ Args:
800
+ args (dict, optional): Arguments to print.
801
+ show_file (bool): Whether to show the file name.
802
+ show_func (bool): Whether to show the function name.
803
+ """
738
804
 
739
805
  def strip_auth(v):
740
806
  """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
@@ -786,7 +852,12 @@ def cuda_is_available() -> bool:
786
852
 
787
853
 
788
854
  def is_rockchip():
789
- """Check if the current environment is running on a Rockchip SoC."""
855
+ """
856
+ Check if the current environment is running on a Rockchip SoC.
857
+
858
+ Returns:
859
+ (bool): True if running on a Rockchip SoC, False otherwise.
860
+ """
790
861
  if LINUX and ARM64:
791
862
  try:
792
863
  with open("/proc/device-tree/compatible") as f:
ultralytics/utils/dist.py CHANGED
@@ -12,10 +12,13 @@ from .torch_utils import TORCH_1_9
12
12
 
13
13
  def find_free_network_port() -> int:
14
14
  """
15
- Finds a free port on localhost.
15
+ Find a free port on localhost.
16
16
 
17
17
  It is useful in single-node training when we don't want to connect to a real main node but have to set the
18
18
  `MASTER_PORT` environment variable.
19
+
20
+ Returns:
21
+ (int): The available network port number.
19
22
  """
20
23
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
21
24
  s.bind(("127.0.0.1", 0))
@@ -54,7 +57,17 @@ if __name__ == "__main__":
54
57
 
55
58
 
56
59
  def generate_ddp_command(world_size, trainer):
57
- """Generates and returns command for distributed training."""
60
+ """
61
+ Generate command for distributed training.
62
+
63
+ Args:
64
+ world_size (int): Number of processes to spawn for distributed training.
65
+ trainer (object): The trainer object containing configuration for distributed training.
66
+
67
+ Returns:
68
+ cmd (List[str]): The command to execute for distributed training.
69
+ file (str): Path to the temporary file created for DDP training.
70
+ """
58
71
  import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
59
72
 
60
73
  if not trainer.resume: