ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +12 -12
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +16 -8
  96. ultralytics/solutions/object_cropper.py +12 -5
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +215 -85
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.142.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
@@ -139,7 +139,8 @@ class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm):
139
139
  A custom TQDM progress bar class that extends the original tqdm functionality.
140
140
 
141
141
  This class modifies the behavior of the original tqdm progress bar based on global settings and provides
142
- additional customization options.
142
+ additional customization options for Ultralytics projects. The progress bar is automatically disabled when
143
+ VERBOSE is False or when explicitly disabled.
143
144
 
144
145
  Attributes:
145
146
  disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and
@@ -148,7 +149,8 @@ class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm):
148
149
  explicitly set.
149
150
 
150
151
  Methods:
151
- __init__: Initializes the TQDM object with custom settings.
152
+ __init__: Initialize the TQDM object with custom settings.
153
+ __iter__: Return self as iterator to satisfy Iterable interface.
152
154
 
153
155
  Examples:
154
156
  >>> from ultralytics.utils import TQDM
@@ -159,9 +161,7 @@ class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm):
159
161
 
160
162
  def __init__(self, *args, **kwargs):
161
163
  """
162
- Initializes a custom TQDM progress bar.
163
-
164
- This class extends the original tqdm class to provide customized behavior for Ultralytics projects.
164
+ Initialize a custom TQDM progress bar with Ultralytics-specific settings.
165
165
 
166
166
  Args:
167
167
  *args (Any): Variable length argument list to be passed to the original tqdm constructor.
@@ -192,17 +192,17 @@ class DataExportMixin:
192
192
  Mixin class for exporting validation metrics or prediction results in various formats.
193
193
 
194
194
  This class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results
195
- from classification, object detection, segmentation, or pose estimation tasks into various formats, Pandas DataFrame
196
- CSV, XML, HTML, JSON and SQLite (SQL)
195
+ from classification, object detection, segmentation, or pose estimation tasks into various formats: Pandas
196
+ DataFrame, CSV, XML, HTML, JSON and SQLite (SQL).
197
197
 
198
198
  Methods:
199
- to_df(): Convert summary to a Pandas DataFrame.
200
- to_csv(): Export results as a CSV string.
201
- to_xml(): Export results as an XML string (requires `lxml`).
202
- to_html(): Export results as an HTML table.
203
- to_json(): Export results as a JSON string.
204
- tojson(): Deprecated alias for `to_json()`.
205
- to_sql(): Export results to an SQLite database.
199
+ to_df: Convert summary to a Pandas DataFrame.
200
+ to_csv: Export results as a CSV string.
201
+ to_xml: Export results as an XML string (requires `lxml`).
202
+ to_html: Export results as an HTML table.
203
+ to_json: Export results as a JSON string.
204
+ tojson: Deprecated alias for `to_json()`.
205
+ to_sql: Export results to an SQLite database.
206
206
 
207
207
  Examples:
208
208
  >>> model = YOLO("yolov8n.pt")
@@ -218,8 +218,8 @@ class DataExportMixin:
218
218
  Create a pandas DataFrame from the prediction results summary or validation metrics.
219
219
 
220
220
  Args:
221
- normalize (bool, optional): Normalize numerical values for easier comparison. Defaults to False.
222
- decimals (int, optional): Decimal places to round floats. Defaults to 5.
221
+ normalize (bool, optional): Normalize numerical values for easier comparison.
222
+ decimals (int, optional): Decimal places to round floats.
223
223
 
224
224
  Returns:
225
225
  (DataFrame): DataFrame containing the summary data.
@@ -233,8 +233,8 @@ class DataExportMixin:
233
233
  Export results to CSV string format.
234
234
 
235
235
  Args:
236
- normalize (bool, optional): Normalize numeric values. Defaults to False.
237
- decimals (int, optional): Decimal precision. Defaults to 5.
236
+ normalize (bool, optional): Normalize numeric values.
237
+ decimals (int, optional): Decimal precision.
238
238
 
239
239
  Returns:
240
240
  (str): CSV content as string.
@@ -246,13 +246,13 @@ class DataExportMixin:
246
246
  Export results to XML format.
247
247
 
248
248
  Args:
249
- normalize (bool, optional): Normalize numeric values. Defaults to False.
250
- decimals (int, optional): Decimal precision. Defaults to 5.
249
+ normalize (bool, optional): Normalize numeric values.
250
+ decimals (int, optional): Decimal precision.
251
251
 
252
252
  Returns:
253
253
  (str): XML string.
254
254
 
255
- Note:
255
+ Notes:
256
256
  Requires `lxml` package to be installed.
257
257
  """
258
258
  from ultralytics.utils.checks import check_requirements
@@ -266,9 +266,9 @@ class DataExportMixin:
266
266
  Export results to HTML table format.
267
267
 
268
268
  Args:
269
- normalize (bool, optional): Normalize numeric values. Defaults to False.
270
- decimals (int, optional): Decimal precision. Defaults to 5.
271
- index (bool, optional): Whether to include index column in the HTML table. Defaults to False.
269
+ normalize (bool, optional): Normalize numeric values.
270
+ decimals (int, optional): Decimal precision.
271
+ index (bool, optional): Whether to include index column in the HTML table.
272
272
 
273
273
  Returns:
274
274
  (str): HTML representation of the results.
@@ -286,8 +286,8 @@ class DataExportMixin:
286
286
  Export results to JSON format.
287
287
 
288
288
  Args:
289
- normalize (bool, optional): Normalize numeric values. Defaults to False.
290
- decimals (int, optional): Decimal precision. Defaults to 5.
289
+ normalize (bool, optional): Normalize numeric values.
290
+ decimals (int, optional): Decimal precision.
291
291
 
292
292
  Returns:
293
293
  (str): JSON-formatted string of the results.
@@ -299,10 +299,10 @@ class DataExportMixin:
299
299
  Save results to an SQLite database.
300
300
 
301
301
  Args:
302
- normalize (bool, optional): Normalize numeric values. Defaults to False.
303
- decimals (int, optional): Decimal precision. Defaults to 5.
304
- table_name (str, optional): Name of the SQL table. Defaults to "results".
305
- db_path (str, optional): SQLite database file path. Defaults to "results.db".
302
+ normalize (bool, optional): Normalize numeric values.
303
+ decimals (int, optional): Decimal precision.
304
+ table_name (str, optional): Name of the SQL table.
305
+ db_path (str, optional): SQLite database file path.
306
306
  """
307
307
  df = self.to_df(normalize, decimals)
308
308
  if df.empty or df.columns.empty: # Exit if df is None or has no columns (i.e., no schema)
@@ -348,9 +348,9 @@ class SimpleClass:
348
348
  showing all their non-callable attributes. It's useful for debugging and introspection of object states.
349
349
 
350
350
  Methods:
351
- __str__: Returns a human-readable string representation of the object.
352
- __repr__: Returns a machine-readable string representation of the object.
353
- __getattr__: Provides a custom attribute access error message with helpful information.
351
+ __str__: Return a human-readable string representation of the object.
352
+ __repr__: Return a machine-readable string representation of the object.
353
+ __getattr__: Provide a custom attribute access error message with helpful information.
354
354
 
355
355
  Examples:
356
356
  >>> class MyClass(SimpleClass):
@@ -389,7 +389,7 @@ class SimpleClass:
389
389
  return self.__str__()
390
390
 
391
391
  def __getattr__(self, attr):
392
- """Custom attribute access error message with helpful information."""
392
+ """Provide a custom attribute access error message with helpful information."""
393
393
  name = self.__class__.__name__
394
394
  raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
395
395
 
@@ -403,10 +403,10 @@ class IterableSimpleNamespace(SimpleNamespace):
403
403
  configuration parameters.
404
404
 
405
405
  Methods:
406
- __iter__: Returns an iterator of key-value pairs from the namespace's attributes.
407
- __str__: Returns a human-readable string representation of the object.
408
- __getattr__: Provides a custom attribute access error message with helpful information.
409
- get: Retrieves the value of a specified key, or a default value if the key doesn't exist.
406
+ __iter__: Return an iterator of key-value pairs from the namespace's attributes.
407
+ __str__: Return a human-readable string representation of the object.
408
+ __getattr__: Provide a custom attribute access error message with helpful information.
409
+ get: Retrieve the value of a specified key, or a default value if the key doesn't exist.
410
410
 
411
411
  Examples:
412
412
  >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)
@@ -438,7 +438,7 @@ class IterableSimpleNamespace(SimpleNamespace):
438
438
  return "\n".join(f"{k}={v}" for k, v in vars(self).items())
439
439
 
440
440
  def __getattr__(self, attr):
441
- """Custom attribute access error message with helpful information."""
441
+ """Provide a custom attribute access error message with helpful information."""
442
442
  name = self.__class__.__name__
443
443
  raise AttributeError(
444
444
  f"""
@@ -460,7 +460,7 @@ def plt_settings(rcparams=None, backend="Agg"):
460
460
 
461
461
  Args:
462
462
  rcparams (dict, optional): Dictionary of rc parameters to set.
463
- backend (str, optional): Name of the backend to use. Defaults to 'Agg'.
463
+ backend (str, optional): Name of the backend to use.
464
464
 
465
465
  Returns:
466
466
  (Callable): Decorated function with temporarily set rc parameters and backend.
@@ -484,7 +484,7 @@ def plt_settings(rcparams=None, backend="Agg"):
484
484
  """Decorator to apply temporary rc parameters and backend to a function."""
485
485
 
486
486
  def wrapper(*args, **kwargs):
487
- """Sets rc parameters and backend, calls the original function, and restores the settings."""
487
+ """Set rc parameters and backend, call the original function, and restore the settings."""
488
488
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
489
489
 
490
490
  original_backend = plt.get_backend()
@@ -510,15 +510,15 @@ def plt_settings(rcparams=None, backend="Agg"):
510
510
 
511
511
  def set_logging(name="LOGGING_NAME", verbose=True):
512
512
  """
513
- Sets up logging with UTF-8 encoding and configurable verbosity.
513
+ Set up logging with UTF-8 encoding and configurable verbosity.
514
514
 
515
515
  This function configures logging for the Ultralytics library, setting the appropriate logging level and
516
516
  formatter based on the verbosity flag and the current process rank. It handles special cases for Windows
517
517
  environments where UTF-8 encoding might not be the default.
518
518
 
519
519
  Args:
520
- name (str): Name of the logger. Defaults to "LOGGING_NAME".
521
- verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise. Defaults to True.
520
+ name (str): Name of the logger.
521
+ verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise.
522
522
 
523
523
  Returns:
524
524
  (logging.Logger): Configured logger object.
@@ -618,7 +618,7 @@ class ThreadingLocked:
618
618
 
619
619
  @wraps(f)
620
620
  def decorated(*args, **kwargs):
621
- """Applies thread-safety to the decorated function or method."""
621
+ """Apply thread-safety to the decorated function or method."""
622
622
  with self.lock:
623
623
  return f(*args, **kwargs)
624
624
 
@@ -767,7 +767,7 @@ DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
767
767
 
768
768
  def read_device_model() -> str:
769
769
  """
770
- Reads the device model information from the system and caches it for quick access.
770
+ Read the device model information from the system and cache it for quick access.
771
771
 
772
772
  Returns:
773
773
  (str): Kernel release information.
@@ -816,7 +816,7 @@ def is_jupyter():
816
816
  Returns:
817
817
  (bool): True if running inside a Jupyter Notebook, False otherwise.
818
818
 
819
- Note:
819
+ Notes:
820
820
  - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not reliably detectable.
821
821
  - "get_ipython" in globals() method suffers false positives when IPython package installed manually.
822
822
  """
@@ -849,7 +849,7 @@ def is_docker() -> bool:
849
849
 
850
850
  def is_raspberrypi() -> bool:
851
851
  """
852
- Determines if the Python environment is running on a Raspberry Pi.
852
+ Determine if the Python environment is running on a Raspberry Pi.
853
853
 
854
854
  Returns:
855
855
  (bool): True if running on a Raspberry Pi, False otherwise.
@@ -859,7 +859,7 @@ def is_raspberrypi() -> bool:
859
859
 
860
860
  def is_jetson() -> bool:
861
861
  """
862
- Determines if the Python environment is running on an NVIDIA Jetson device.
862
+ Determine if the Python environment is running on an NVIDIA Jetson device.
863
863
 
864
864
  Returns:
865
865
  (bool): True if running on an NVIDIA Jetson device, False otherwise.
@@ -887,7 +887,7 @@ def is_online() -> bool:
887
887
 
888
888
  def is_pip_package(filepath: str = __name__) -> bool:
889
889
  """
890
- Determines if the file at the given filepath is part of a pip package.
890
+ Determine if the file at the given filepath is part of a pip package.
891
891
 
892
892
  Args:
893
893
  filepath (str): The filepath to check.
@@ -919,7 +919,7 @@ def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
919
919
 
920
920
  def is_pytest_running():
921
921
  """
922
- Determines whether pytest is currently running or not.
922
+ Determine whether pytest is currently running or not.
923
923
 
924
924
  Returns:
925
925
  (bool): True if pytest is running, False otherwise.
@@ -939,7 +939,7 @@ def is_github_action_running() -> bool:
939
939
 
940
940
  def get_git_dir():
941
941
  """
942
- Determines whether the current file is part of a git repository and if so, returns the repository root directory.
942
+ Determine whether the current file is part of a git repository and if so, return the repository root directory.
943
943
 
944
944
  Returns:
945
945
  (Path | None): Git root directory if found or None if not found.
@@ -951,7 +951,7 @@ def get_git_dir():
951
951
 
952
952
  def is_git_dir():
953
953
  """
954
- Determines whether the current file is part of a git repository.
954
+ Determine whether the current file is part of a git repository.
955
955
 
956
956
  Returns:
957
957
  (bool): True if current file is part of a git repository.
@@ -961,7 +961,7 @@ def is_git_dir():
961
961
 
962
962
  def get_git_origin_url():
963
963
  """
964
- Retrieves the origin URL of a git repository.
964
+ Retrieve the origin URL of a git repository.
965
965
 
966
966
  Returns:
967
967
  (str | None): The origin URL of the git repository or None if not git directory.
@@ -976,7 +976,7 @@ def get_git_origin_url():
976
976
 
977
977
  def get_git_branch():
978
978
  """
979
- Returns the current git branch name. If not in a git repository, returns None.
979
+ Return the current git branch name. If not in a git repository, return None.
980
980
 
981
981
  Returns:
982
982
  (str | None): The current git branch name or None if not a git directory.
@@ -991,7 +991,7 @@ def get_git_branch():
991
991
 
992
992
  def get_default_args(func):
993
993
  """
994
- Returns a dictionary of default arguments for a function.
994
+ Return a dictionary of default arguments for a function.
995
995
 
996
996
  Args:
997
997
  func (callable): The function to inspect.
@@ -1069,8 +1069,7 @@ SETTINGS_FILE = USER_CONFIG_DIR / "settings.json"
1069
1069
 
1070
1070
  def colorstr(*input):
1071
1071
  r"""
1072
- Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
1073
- See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
1072
+ Color a string based on the provided color and style arguments using ANSI escape codes.
1074
1073
 
1075
1074
  This function can be called in two ways:
1076
1075
  - colorstr('color', 'style', 'your string')
@@ -1082,18 +1081,22 @@ def colorstr(*input):
1082
1081
  *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,
1083
1082
  and the last string is the one to be colored.
1084
1083
 
1085
- Supported Colors and Styles:
1086
- Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
1087
- Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
1088
- 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
1089
- Misc: 'end', 'bold', 'underline'
1090
-
1091
1084
  Returns:
1092
1085
  (str): The input string wrapped with ANSI escape codes for the specified color and style.
1093
1086
 
1087
+ Notes:
1088
+ Supported Colors and Styles:
1089
+ - Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
1090
+ - Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
1091
+ 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
1092
+ - Misc: 'end', 'bold', 'underline'
1093
+
1094
1094
  Examples:
1095
1095
  >>> colorstr("blue", "bold", "hello world")
1096
1096
  >>> "\033[34m\033[1mhello world\033[0m"
1097
+
1098
+ References:
1099
+ https://en.wikipedia.org/wiki/ANSI_escape_code
1097
1100
  """
1098
1101
  *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
1099
1102
  colors = {
@@ -1122,7 +1125,7 @@ def colorstr(*input):
1122
1125
 
1123
1126
  def remove_colorstr(input_string):
1124
1127
  """
1125
- Removes ANSI escape codes from a string, effectively un-coloring it.
1128
+ Remove ANSI escape codes from a string, effectively un-coloring it.
1126
1129
 
1127
1130
  Args:
1128
1131
  input_string (str): The string to remove color and style from.
@@ -1140,7 +1143,14 @@ def remove_colorstr(input_string):
1140
1143
 
1141
1144
  class TryExcept(contextlib.ContextDecorator):
1142
1145
  """
1143
- Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager.
1146
+ Ultralytics TryExcept class for handling exceptions gracefully.
1147
+
1148
+ This class can be used as a decorator or context manager to catch exceptions and optionally print warning messages.
1149
+ It allows code to continue execution even when exceptions occur, which is useful for non-critical operations.
1150
+
1151
+ Attributes:
1152
+ msg (str): Optional message to display when an exception occurs.
1153
+ verbose (bool): Whether to print the exception message.
1144
1154
 
1145
1155
  Examples:
1146
1156
  As a decorator:
@@ -1161,11 +1171,11 @@ class TryExcept(contextlib.ContextDecorator):
1161
1171
  self.verbose = verbose
1162
1172
 
1163
1173
  def __enter__(self):
1164
- """Executes when entering TryExcept context, initializes instance."""
1174
+ """Execute when entering TryExcept context, initialize instance."""
1165
1175
  pass
1166
1176
 
1167
1177
  def __exit__(self, exc_type, value, traceback):
1168
- """Defines behavior when exiting a 'with' block, prints error message if necessary."""
1178
+ """Define behavior when exiting a 'with' block, print error message if necessary."""
1169
1179
  if self.verbose and value:
1170
1180
  LOGGER.warning(f"{self.msg}{': ' if self.msg else ''}{value}")
1171
1181
  return True
@@ -1175,8 +1185,13 @@ class Retry(contextlib.ContextDecorator):
1175
1185
  """
1176
1186
  Retry class for function execution with exponential backoff.
1177
1187
 
1178
- Can be used as a decorator to retry a function on exceptions, up to a specified number of times with an
1179
- exponentially increasing delay between retries.
1188
+ This decorator can be used to retry a function on exceptions, up to a specified number of times with an
1189
+ exponentially increasing delay between retries. It's useful for handling transient failures in network
1190
+ operations or other unreliable processes.
1191
+
1192
+ Attributes:
1193
+ times (int): Maximum number of retry attempts.
1194
+ delay (int): Initial delay between retries in seconds.
1180
1195
 
1181
1196
  Examples:
1182
1197
  Example usage as a decorator:
@@ -1196,7 +1211,7 @@ class Retry(contextlib.ContextDecorator):
1196
1211
  """Decorator implementation for Retry with exponential backoff."""
1197
1212
 
1198
1213
  def wrapped_func(*args, **kwargs):
1199
- """Applies retries to the decorated function or method."""
1214
+ """Apply retries to the decorated function or method."""
1200
1215
  self._attempts = 0
1201
1216
  while self._attempts < self.times:
1202
1217
  try:
@@ -1213,7 +1228,7 @@ class Retry(contextlib.ContextDecorator):
1213
1228
 
1214
1229
  def threaded(func):
1215
1230
  """
1216
- Multi-threads a target function by default and returns the thread or function result.
1231
+ Multi-thread a target function by default and return the thread or function result.
1217
1232
 
1218
1233
  This decorator provides flexible execution of the target function, either in a separate thread or synchronously.
1219
1234
  By default, the function runs in a thread, but this can be controlled via the 'threaded=False' keyword argument
@@ -1235,7 +1250,7 @@ def threaded(func):
1235
1250
  """
1236
1251
 
1237
1252
  def wrapper(*args, **kwargs):
1238
- """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
1253
+ """Multi-thread a given function based on 'threaded' kwarg and return the thread or function result."""
1239
1254
  if kwargs.pop("threaded", True): # run in thread
1240
1255
  thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
1241
1256
  thread.start()
@@ -1288,7 +1303,7 @@ def set_sentry():
1288
1303
  hint (dict): A dictionary containing additional information about the error.
1289
1304
 
1290
1305
  Returns:
1291
- dict: The modified event or None if the event should not be sent to Sentry.
1306
+ (dict | None): The modified event or None if the event should not be sent to Sentry.
1292
1307
  """
1293
1308
  if "exc_info" in hint:
1294
1309
  exc_type, exc_value, _ = hint["exc_info"]
@@ -1321,19 +1336,19 @@ class JSONDict(dict):
1321
1336
  A dictionary-like class that provides JSON persistence for its contents.
1322
1337
 
1323
1338
  This class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are
1324
- modified. It ensures thread-safe operations using a lock.
1339
+ modified. It ensures thread-safe operations using a lock and handles JSON serialization of Path objects.
1325
1340
 
1326
1341
  Attributes:
1327
1342
  file_path (Path): The path to the JSON file used for persistence.
1328
1343
  lock (threading.Lock): A lock object to ensure thread-safe operations.
1329
1344
 
1330
1345
  Methods:
1331
- _load: Loads the data from the JSON file into the dictionary.
1332
- _save: Saves the current state of the dictionary to the JSON file.
1333
- __setitem__: Stores a key-value pair and persists it to disk.
1334
- __delitem__: Removes an item and updates the persistent storage.
1335
- update: Updates the dictionary and persists changes.
1336
- clear: Clears all entries and updates the persistent storage.
1346
+ _load: Load the data from the JSON file into the dictionary.
1347
+ _save: Save the current state of the dictionary to the JSON file.
1348
+ __setitem__: Store a key-value pair and persist it to disk.
1349
+ __delitem__: Remove an item and update the persistent storage.
1350
+ update: Update the dictionary and persist changes.
1351
+ clear: Clear all entries and update the persistent storage.
1337
1352
 
1338
1353
  Examples:
1339
1354
  >>> json_dict = JSONDict("data.json")
@@ -1414,7 +1429,8 @@ class SettingsManager(JSONDict):
1414
1429
  SettingsManager class for managing and persisting Ultralytics settings.
1415
1430
 
1416
1431
  This class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default
1417
- values. It validates settings on initialization and provides methods to update or reset settings.
1432
+ values. It validates settings on initialization and provides methods to update or reset settings. The settings
1433
+ include directories for datasets, weights, and runs, as well as various integration flags.
1418
1434
 
1419
1435
  Attributes:
1420
1436
  file (Path): The path to the JSON file used for persistence.
@@ -1423,9 +1439,9 @@ class SettingsManager(JSONDict):
1423
1439
  help_msg (str): A help message for users on how to view and update settings.
1424
1440
 
1425
1441
  Methods:
1426
- _validate_settings: Validates the current settings and resets if necessary.
1427
- update: Updates settings, validating keys and types.
1428
- reset: Resets the settings to default and saves them.
1442
+ _validate_settings: Validate the current settings and reset if necessary.
1443
+ update: Update settings, validating keys and types.
1444
+ reset: Reset the settings to default and save them.
1429
1445
 
1430
1446
  Examples:
1431
1447
  Initialize and update settings:
@@ -1436,7 +1452,7 @@ class SettingsManager(JSONDict):
1436
1452
  """
1437
1453
 
1438
1454
  def __init__(self, file=SETTINGS_FILE, version="0.0.6"):
1439
- """Initializes the SettingsManager with default settings and loads user settings."""
1455
+ """Initialize the SettingsManager with default settings and load user settings."""
1440
1456
  import hashlib
1441
1457
  import uuid
1442
1458
 
@@ -1505,11 +1521,11 @@ class SettingsManager(JSONDict):
1505
1521
  )
1506
1522
 
1507
1523
  def __setitem__(self, key, value):
1508
- """Updates one key: value pair."""
1524
+ """Update one key: value pair."""
1509
1525
  self.update({key: value})
1510
1526
 
1511
1527
  def update(self, *args, **kwargs):
1512
- """Updates settings, validating keys and types."""
1528
+ """Update settings, validating keys and types."""
1513
1529
  for arg in args:
1514
1530
  if isinstance(arg, dict):
1515
1531
  kwargs.update(arg)
@@ -1524,7 +1540,7 @@ class SettingsManager(JSONDict):
1524
1540
  super().update(*args, **kwargs)
1525
1541
 
1526
1542
  def reset(self):
1527
- """Resets the settings to default and saves them."""
1543
+ """Reset the settings to default and save them."""
1528
1544
  self.clear()
1529
1545
  self.update(self.defaults)
1530
1546
 
@@ -3,6 +3,7 @@
3
3
 
4
4
  import os
5
5
  from copy import deepcopy
6
+ from typing import Union
6
7
 
7
8
  import numpy as np
8
9
  import torch
@@ -11,7 +12,13 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
11
12
  from ultralytics.utils.torch_utils import autocast, profile_ops
12
13
 
13
14
 
14
- def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
15
+ def check_train_batch_size(
16
+ model: torch.nn.Module,
17
+ imgsz: int = 640,
18
+ amp: bool = True,
19
+ batch: Union[int, float] = -1,
20
+ max_num_obj: int = 1,
21
+ ) -> int:
15
22
  """
16
23
  Compute optimal YOLO training batch size using the autobatch() function.
17
24
 
@@ -19,7 +26,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
19
26
  model (torch.nn.Module): YOLO model to check batch size for.
20
27
  imgsz (int, optional): Image size used for training.
21
28
  amp (bool, optional): Use automatic mixed precision if True.
22
- batch (float, optional): Fraction of GPU memory to use. If -1, use default.
29
+ batch (int | float, optional): Fraction of GPU memory to use. If -1, use default.
23
30
  max_num_obj (int, optional): The maximum number of objects from dataset.
24
31
 
25
32
  Returns:
@@ -35,7 +42,13 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
35
42
  )
36
43
 
37
44
 
38
- def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch, max_num_obj=1):
45
+ def autobatch(
46
+ model: torch.nn.Module,
47
+ imgsz: int = 640,
48
+ fraction: float = 0.60,
49
+ batch_size: int = DEFAULT_CFG.batch,
50
+ max_num_obj: int = 1,
51
+ ) -> int:
39
52
  """
40
53
  Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
41
54