dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +52 -51
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +191 -161
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +4 -6
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +2 -2
  95. ultralytics/solutions/instance_segmentation.py +7 -4
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -11
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +189 -79
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +45 -29
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
@@ -36,7 +36,7 @@ def _log_images(path: Path, prefix: str = "") -> None:
36
36
 
37
37
  Args:
38
38
  path (Path): Path to the image file to be logged.
39
- prefix (str): Optional prefix to add to the image name when logging.
39
+ prefix (str, optional): Optional prefix to add to the image name when logging.
40
40
 
41
41
  Examples:
42
42
  >>> from pathlib import Path
@@ -77,11 +77,8 @@ def _log_confusion_matrix(validator) -> None:
77
77
  the matrix into lists of target and prediction labels.
78
78
 
79
79
  Args:
80
- validator (BaseValidator): The validator object containing the confusion matrix and class names.
81
- Must have attributes: confusion_matrix.matrix, confusion_matrix.task, and names.
82
-
83
- Returns:
84
- None
80
+ validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have
81
+ attributes: confusion_matrix.matrix, confusion_matrix.task, and names.
85
82
  """
86
83
  targets = []
87
84
  preds = []
@@ -99,7 +96,7 @@ def _log_confusion_matrix(validator) -> None:
99
96
 
100
97
 
101
98
  def on_pretrain_routine_start(trainer) -> None:
102
- """Initializes DVCLive logger for training metadata during pre-training routine."""
99
+ """Initialize DVCLive logger for training metadata during pre-training routine."""
103
100
  try:
104
101
  global live
105
102
  live = dvclive.Live(save_dvc_exp=True, cache_images=True)
@@ -109,18 +106,18 @@ def on_pretrain_routine_start(trainer) -> None:
109
106
 
110
107
 
111
108
  def on_pretrain_routine_end(trainer) -> None:
112
- """Logs plots related to the training process at the end of the pretraining routine."""
109
+ """Log plots related to the training process at the end of the pretraining routine."""
113
110
  _log_plots(trainer.plots, "train")
114
111
 
115
112
 
116
113
  def on_train_start(trainer) -> None:
117
- """Logs the training parameters if DVCLive logging is active."""
114
+ """Log the training parameters if DVCLive logging is active."""
118
115
  if live:
119
116
  live.log_params(trainer.args)
120
117
 
121
118
 
122
119
  def on_train_epoch_start(trainer) -> None:
123
- """Sets the global variable _training_epoch value to True at the start of training each epoch."""
120
+ """Set the global variable _training_epoch value to True at the start of training each epoch."""
124
121
  global _training_epoch
125
122
  _training_epoch = True
126
123
 
@@ -55,14 +55,11 @@ def on_pretrain_routine_end(trainer):
55
55
  Args:
56
56
  trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.
57
57
 
58
- Global:
59
- mlflow: The imported mlflow module to use for logging.
60
-
61
58
  Environment Variables:
62
59
  MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
63
60
  MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
64
61
  MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
65
- MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of training.
62
+ MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.
66
63
  """
67
64
  global mlflow
68
65
 
@@ -107,7 +104,7 @@ def on_fit_epoch_end(trainer):
107
104
 
108
105
 
109
106
  def on_train_end(trainer):
110
- """Log model artifacts at the end of the training."""
107
+ """Log model artifacts at the end of training."""
111
108
  if not mlflow:
112
109
  return
113
110
  mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
@@ -23,7 +23,7 @@ def _log_scalars(scalars: dict, step: int = 0) -> None:
23
23
 
24
24
  Args:
25
25
  scalars (dict): Dictionary of scalar values to log to NeptuneAI.
26
- step (int): The current step or iteration number for logging.
26
+ step (int, optional): The current step or iteration number for logging.
27
27
 
28
28
  Examples:
29
29
  >>> metrics = {"mAP": 0.85, "loss": 0.32}
@@ -55,13 +55,7 @@ def _log_images(imgs_dict: dict, group: str = "") -> None:
55
55
 
56
56
 
57
57
  def _log_plot(title: str, plot_path: str) -> None:
58
- """
59
- Log plots to the NeptuneAI experiment logger.
60
-
61
- Args:
62
- title (str): Title of the plot.
63
- plot_path (str): Path to the saved image file.
64
- """
58
+ """Log plots to the NeptuneAI experiment logger."""
65
59
  import matplotlib.image as mpimg
66
60
  import matplotlib.pyplot as plt
67
61
 
@@ -73,7 +67,7 @@ def _log_plot(title: str, plot_path: str) -> None:
73
67
 
74
68
 
75
69
  def on_pretrain_routine_start(trainer) -> None:
76
- """Callback function called before the training routine starts."""
70
+ """Initialize NeptuneAI run and log hyperparameters before training starts."""
77
71
  try:
78
72
  global run
79
73
  run = neptune.init_run(
@@ -87,7 +81,7 @@ def on_pretrain_routine_start(trainer) -> None:
87
81
 
88
82
 
89
83
  def on_train_epoch_end(trainer) -> None:
90
- """Callback function called at end of each training epoch."""
84
+ """Log training metrics and learning rate at the end of each training epoch."""
91
85
  _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
92
86
  _log_scalars(trainer.lr, trainer.epoch + 1)
93
87
  if trainer.epoch == 1:
@@ -95,7 +89,7 @@ def on_train_epoch_end(trainer) -> None:
95
89
 
96
90
 
97
91
  def on_fit_epoch_end(trainer) -> None:
98
- """Callback function called at end of each fit (train+val) epoch."""
92
+ """Log model info and validation metrics at the end of each fit epoch."""
99
93
  if run and trainer.epoch == 0:
100
94
  from ultralytics.utils.torch_utils import model_info_for_loggers
101
95
 
@@ -104,14 +98,14 @@ def on_fit_epoch_end(trainer) -> None:
104
98
 
105
99
 
106
100
  def on_val_end(validator) -> None:
107
- """Callback function called at end of each validation."""
101
+ """Log validation images at the end of validation."""
108
102
  if run:
109
103
  # Log val_labels and val_pred
110
104
  _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
111
105
 
112
106
 
113
107
  def on_train_end(trainer) -> None:
114
- """Callback function called at end of training."""
108
+ """Log final results, plots, and model weights at the end of training."""
115
109
  if run:
116
110
  # Log final results, CM matrix + PR plots
117
111
  files = [
@@ -14,7 +14,7 @@ except (ImportError, AssertionError):
14
14
 
15
15
  def on_fit_epoch_end(trainer):
16
16
  """
17
- Reports training metrics to Ray Tune at epoch end when a Ray session is active.
17
+ Report training metrics to Ray Tune at epoch end when a Ray session is active.
18
18
 
19
19
  Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number,
20
20
  enabling hyperparameter tuning optimization. Only executes when within an active Ray Tune session.
@@ -31,7 +31,7 @@ def _log_scalars(scalars: dict, step: int = 0) -> None:
31
31
  step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.
32
32
 
33
33
  Examples:
34
- >>> # Log training metrics
34
+ Log training metrics
35
35
  >>> metrics = {"loss": 0.5, "accuracy": 0.95}
36
36
  >>> _log_scalars(metrics, step=100)
37
37
  """
@@ -49,9 +49,8 @@ def _log_tensorboard_graph(trainer) -> None:
49
49
  approach for models like RTDETR that may require special handling.
50
50
 
51
51
  Args:
52
- trainer (BaseTrainer): The trainer object containing the model to visualize. Must have attributes:
53
- - model: PyTorch model to visualize
54
- - args: Configuration arguments with 'imgsz' attribute
52
+ trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize.
53
+ Must have attributes model and args with imgsz.
55
54
 
56
55
  Notes:
57
56
  This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.
@@ -110,13 +109,13 @@ def on_train_start(trainer) -> None:
110
109
 
111
110
 
112
111
  def on_train_epoch_end(trainer) -> None:
113
- """Logs scalar statistics at the end of a training epoch."""
112
+ """Log scalar statistics at the end of a training epoch."""
114
113
  _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
115
114
  _log_scalars(trainer.lr, trainer.epoch + 1)
116
115
 
117
116
 
118
117
  def on_fit_epoch_end(trainer) -> None:
119
- """Logs epoch metrics at end of training epoch."""
118
+ """Log epoch metrics at end of training epoch."""
120
119
  _log_scalars(trainer.metrics, trainer.epoch + 1)
121
120
 
122
121
 
@@ -27,9 +27,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
27
27
  x (list): Values for the x-axis; expected to have length N.
28
28
  y (list): Corresponding values for the y-axis; also expected to have length N.
29
29
  classes (list): Labels identifying the class of each point; length N.
30
- title (str): Title for the plot; defaults to 'Precision Recall Curve'.
31
- x_title (str): Label for the x-axis; defaults to 'Recall'.
32
- y_title (str): Label for the y-axis; defaults to 'Precision'.
30
+ title (str, optional): Title for the plot.
31
+ x_title (str, optional): Label for the x-axis.
32
+ y_title (str, optional): Label for the y-axis.
33
33
 
34
34
  Returns:
35
35
  (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
@@ -64,13 +64,13 @@ def _plot_curve(
64
64
  Args:
65
65
  x (np.ndarray): Data points for the x-axis with length N.
66
66
  y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
67
- names (list): Names of the classes corresponding to the y-axis data; length C.
68
- id (str): Unique identifier for the logged data in wandb.
69
- title (str): Title for the visualization plot.
70
- x_title (str): Label for the x-axis.
71
- y_title (str): Label for the y-axis.
72
- num_x (int): Number of interpolated data points for visualization.
73
- only_mean (bool): Flag to indicate if only the mean curve should be plotted.
67
+ names (list, optional): Names of the classes corresponding to the y-axis data; length C.
68
+ id (str, optional): Unique identifier for the logged data in wandb.
69
+ title (str, optional): Title for the visualization plot.
70
+ x_title (str, optional): Label for the x-axis.
71
+ y_title (str, optional): Label for the y-axis.
72
+ num_x (int, optional): Number of interpolated data points for visualization.
73
+ only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.
74
74
 
75
75
  Notes:
76
76
  The function leverages the '_custom_table' function to generate the actual visualization.
@@ -111,9 +111,9 @@ def _log_plots(plots, step):
111
111
  step (int): The step/epoch at which to log the plots in the WandB run.
112
112
 
113
113
  Notes:
114
- - The function uses a shallow copy of the plots dictionary to prevent modification during iteration
115
- - Plots are identified by their stem name (filename without extension)
116
- - Each plot is logged as a WandB Image object
114
+ The function uses a shallow copy of the plots dictionary to prevent modification during iteration.
115
+ Plots are identified by their stem name (filename without extension).
116
+ Each plot is logged as a WandB Image object.
117
117
  """
118
118
  for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration
119
119
  timestamp = params["timestamp"]
@@ -123,7 +123,7 @@ def _log_plots(plots, step):
123
123
 
124
124
 
125
125
  def on_pretrain_routine_start(trainer):
126
- """Initiate and start wandb project if module is present."""
126
+ """Initialize and start wandb project if module is present."""
127
127
  if not wb.run:
128
128
  wb.init(
129
129
  project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
@@ -58,7 +58,8 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
58
58
  package (str, optional): Python package to use instead of requirements.txt file.
59
59
 
60
60
  Returns:
61
- (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and `specifier` attributes.
61
+ requirements (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and
62
+ `specifier` attributes.
62
63
 
63
64
  Examples:
64
65
  >>> from ultralytics.utils.checks import parse_requirements
@@ -89,7 +90,7 @@ def parse_version(version="0.0.0") -> tuple:
89
90
  version (str): Version string, i.e. '2.0.1+cpu'
90
91
 
91
92
  Returns:
92
- (Tuple[int, int, int]): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)
93
+ (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)
93
94
  """
94
95
  try:
95
96
  return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
@@ -167,7 +168,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
167
168
 
168
169
  @functools.lru_cache
169
170
  def check_uv():
170
- """Check if uv is installed and can run successfully."""
171
+ """Check if uv package manager is installed and can run successfully."""
171
172
  try:
172
173
  return subprocess.run(["uv", "-V"], capture_output=True).returncode == 0
173
174
  except FileNotFoundError:
@@ -265,7 +266,7 @@ def check_version(
265
266
 
266
267
  def check_latest_pypi_version(package_name="ultralytics"):
267
268
  """
268
- Returns the latest version of a PyPI package without downloading or installing it.
269
+ Return the latest version of a PyPI package without downloading or installing it.
269
270
 
270
271
  Args:
271
272
  package_name (str): The name of the package to find the latest version for.
@@ -286,7 +287,7 @@ def check_latest_pypi_version(package_name="ultralytics"):
286
287
 
287
288
  def check_pip_update_available():
288
289
  """
289
- Checks if a new version of the ultralytics package is available on PyPI.
290
+ Check if a new version of the ultralytics package is available on PyPI.
290
291
 
291
292
  Returns:
292
293
  (bool): True if an update is available, False otherwise.
@@ -360,9 +361,9 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
360
361
  Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed.
361
362
 
362
363
  Args:
363
- requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
364
+ requirements (Path | str | List[str]): Path to a requirements.txt file, a single package requirement as a
364
365
  string, or a list of package requirements as strings.
365
- exclude (Tuple[str]): Tuple of package names to exclude from checking.
366
+ exclude (tuple): Tuple of package names to exclude from checking.
366
367
  install (bool): If True, attempt to auto-update packages that don't meet requirements.
367
368
  cmds (str): Additional commands to pass to the pip install command when auto-updating.
368
369
 
@@ -432,7 +433,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
432
433
 
433
434
  def check_torchvision():
434
435
  """
435
- Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
436
+ Check the installed versions of PyTorch and Torchvision to ensure they're compatible.
436
437
 
437
438
  This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
438
439
  to the compatibility table based on: https://github.com/pytorch/vision#installation.
@@ -470,7 +471,7 @@ def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
470
471
 
471
472
  Args:
472
473
  file (str | List[str]): File or list of files to check.
473
- suffix (str | Tuple[str]): Acceptable suffix or tuple of suffixes.
474
+ suffix (str | tuple): Acceptable suffix or tuple of suffixes.
474
475
  msg (str): Additional message to display in case of error.
475
476
  """
476
477
  if file and suffix:
@@ -531,7 +532,7 @@ def check_file(file, suffix="", download=True, download_dir=".", hard=True):
531
532
 
532
533
  Args:
533
534
  file (str): File name or path.
534
- suffix (str | Tuple[str]): Acceptable suffix or tuple of suffixes to validate against the file.
535
+ suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file.
535
536
  download (bool): Whether to download the file if it doesn't exist locally.
536
537
  download_dir (str): Directory to download the file to.
537
538
  hard (bool): Whether to raise an error if the file is not found.
@@ -571,7 +572,7 @@ def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
571
572
 
572
573
  Args:
573
574
  file (str | Path): File name or path.
574
- suffix (Tuple[str]): Tuple of acceptable YAML file suffixes.
575
+ suffix (tuple): Tuple of acceptable YAML file suffixes.
575
576
  hard (bool): Whether to raise an error if the file is not found or multiple files are found.
576
577
 
577
578
  Returns:
@@ -720,13 +721,13 @@ def collect_system_info():
720
721
 
721
722
  def check_amp(model):
722
723
  """
723
- Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model.
724
+ Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.
724
725
 
725
726
  If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
726
727
  results, so AMP will be disabled during training.
727
728
 
728
729
  Args:
729
- model (nn.Module): A YOLO11 model instance.
730
+ model (torch.nn.Module): A YOLO model instance.
730
731
 
731
732
  Returns:
732
733
  (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.
ultralytics/utils/dist.py CHANGED
@@ -34,8 +34,8 @@ def generate_ddp_file(trainer):
34
34
  The file contains the necessary configuration to initialize the trainer in a distributed environment.
35
35
 
36
36
  Args:
37
- trainer (object): The trainer object containing training configuration and arguments.
38
- Must have args attribute and be a class instance.
37
+ trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.
38
+ Must have args attribute and be a class instance.
39
39
 
40
40
  Returns:
41
41
  (str): Path to the generated temporary DDP file.
@@ -76,13 +76,13 @@ if __name__ == "__main__":
76
76
  return file.name
77
77
 
78
78
 
79
- def generate_ddp_command(world_size, trainer):
79
+ def generate_ddp_command(world_size: int, trainer):
80
80
  """
81
81
  Generate command for distributed training.
82
82
 
83
83
  Args:
84
84
  world_size (int): Number of processes to spawn for distributed training.
85
- trainer (object): The trainer object containing configuration for distributed training.
85
+ trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.
86
86
 
87
87
  Returns:
88
88
  cmd (List[str]): The command to execute for distributed training.
@@ -107,7 +107,7 @@ def ddp_cleanup(trainer, file):
107
107
  as a temporary file for DDP training, and deletes it if so.
108
108
 
109
109
  Args:
110
- trainer (object): The trainer object used for distributed training.
110
+ trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.
111
111
  file (str): Path to the file that might need to be deleted.
112
112
 
113
113
  Examples: