ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148):
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +12 -12
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +16 -8
  96. ultralytics/solutions/object_cropper.py +12 -5
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +215 -85
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.142.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,7 @@ import subprocess
6
6
  from itertools import repeat
7
7
  from multiprocessing.pool import ThreadPool
8
8
  from pathlib import Path
9
+ from typing import List, Tuple
9
10
  from urllib import parse, request
10
11
 
11
12
  import torch
@@ -41,21 +42,20 @@ GITHUB_ASSETS_NAMES = frozenset(
41
42
  GITHUB_ASSETS_STEMS = frozenset(k.rpartition(".")[0] for k in GITHUB_ASSETS_NAMES)
42
43
 
43
44
 
44
- def is_url(url, check=False):
45
+ def is_url(url, check: bool = False) -> bool:
45
46
  """
46
- Validates if the given string is a URL and optionally checks if the URL exists online.
47
+ Validate if the given string is a URL and optionally check if the URL exists online.
47
48
 
48
49
  Args:
49
50
  url (str): The string to be validated as a URL.
50
51
  check (bool, optional): If True, performs an additional check to see if the URL exists online.
51
- Defaults to False.
52
52
 
53
53
  Returns:
54
- (bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online.
55
- Returns False otherwise.
54
+ (bool): True for a valid URL. If 'check' is True, also returns True if the URL exists online.
56
55
 
57
56
  Examples:
58
57
  >>> valid = is_url("https://www.example.com")
58
+ >>> valid_and_exists = is_url("https://www.example.com", check=True)
59
59
  """
60
60
  try:
61
61
  url = str(url)
@@ -71,10 +71,10 @@ def is_url(url, check=False):
71
71
 
72
72
  def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
73
73
  """
74
- Delete all ".DS_store" files in a specified directory.
74
+ Delete all specified system files in a directory.
75
75
 
76
76
  Args:
77
- path (str, optional): The directory path where the ".DS_store" files should be deleted.
77
+ path (str | Path): The directory path where the files should be deleted.
78
78
  files_to_delete (tuple): The files to be deleted.
79
79
 
80
80
  Examples:
@@ -92,16 +92,17 @@ def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
92
92
  f.unlink()
93
93
 
94
94
 
95
- def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
95
+ def zip_directory(directory, compress: bool = True, exclude=(".DS_Store", "__MACOSX"), progress: bool = True) -> Path:
96
96
  """
97
- Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
98
- named after the directory and placed alongside it.
97
+ Zip the contents of a directory, excluding specified files.
98
+
99
+ The resulting zip file is named after the directory and placed alongside it.
99
100
 
100
101
  Args:
101
102
  directory (str | Path): The path to the directory to be zipped.
102
- compress (bool): Whether to compress the files while zipping. Default is True.
103
- exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
104
- progress (bool, optional): Whether to display a progress bar. Defaults to True.
103
+ compress (bool): Whether to compress the files while zipping.
104
+ exclude (tuple, optional): A tuple of filename strings to be excluded.
105
+ progress (bool, optional): Whether to display a progress bar.
105
106
 
106
107
  Returns:
107
108
  (Path): The path to the resulting zip file.
@@ -117,7 +118,7 @@ def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), p
117
118
  if not directory.is_dir():
118
119
  raise FileNotFoundError(f"Directory '{directory}' does not exist.")
119
120
 
120
- # Unzip with progress bar
121
+ # Zip with progress bar
121
122
  files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
122
123
  zip_file = directory.with_suffix(".zip")
123
124
  compression = ZIP_DEFLATED if compress else ZIP_STORED
@@ -128,9 +129,15 @@ def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), p
128
129
  return zip_file # return path to zip file
129
130
 
130
131
 
131
- def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
132
+ def unzip_file(
133
+ file,
134
+ path=None,
135
+ exclude=(".DS_Store", "__MACOSX"),
136
+ exist_ok: bool = False,
137
+ progress: bool = True,
138
+ ) -> Path:
132
139
  """
133
- Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
140
+ Unzip a *.zip file to the specified path, excluding specified files.
134
141
 
135
142
  If the zipfile does not contain a single top-level directory, the function will create a new
136
143
  directory with the same name as the zipfile (without the extension) to extract its contents.
@@ -138,17 +145,17 @@ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=Fals
138
145
 
139
146
  Args:
140
147
  file (str | Path): The path to the zipfile to be extracted.
141
- path (str | Path, optional): The path to extract the zipfile to. Defaults to None.
142
- exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
143
- exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
144
- progress (bool, optional): Whether to display a progress bar. Defaults to True.
145
-
146
- Raises:
147
- BadZipFile: If the provided file does not exist or is not a valid zipfile.
148
+ path (str | Path, optional): The path to extract the zipfile to.
149
+ exclude (tuple, optional): A tuple of filename strings to be excluded.
150
+ exist_ok (bool, optional): Whether to overwrite existing contents if they exist.
151
+ progress (bool, optional): Whether to display a progress bar.
148
152
 
149
153
  Returns:
150
154
  (Path): The path to the directory where the zipfile was extracted.
151
155
 
156
+ Raises:
157
+ BadZipFile: If the provided file does not exist or is not a valid zipfile.
158
+
152
159
  Examples:
153
160
  >>> from ultralytics.utils.downloads import unzip_file
154
161
  >>> directory = unzip_file("path/to/file.zip")
@@ -191,15 +198,20 @@ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=Fals
191
198
  return path # return unzip dir
192
199
 
193
200
 
194
- def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.cwd(), sf=1.5, hard=True):
201
+ def check_disk_space(
202
+ url: str = "https://ultralytics.com/assets/coco8.zip",
203
+ path=Path.cwd(),
204
+ sf: float = 1.5,
205
+ hard: bool = True,
206
+ ) -> bool:
195
207
  """
196
208
  Check if there is sufficient disk space to download and store a file.
197
209
 
198
210
  Args:
199
- url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'.
211
+ url (str, optional): The URL to the file.
200
212
  path (str | Path, optional): The path or drive to check the available free space on.
201
- sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
202
- hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.
213
+ sf (float, optional): Safety factor, the multiplier for the required free space.
214
+ hard (bool, optional): Whether to throw an error or not on insufficient disk space.
203
215
 
204
216
  Returns:
205
217
  (bool): True if there is sufficient disk space, False otherwise.
@@ -231,16 +243,16 @@ def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.c
231
243
  return False
232
244
 
233
245
 
234
- def get_google_drive_file_info(link):
246
+ def get_google_drive_file_info(link: str) -> Tuple[str, str]:
235
247
  """
236
- Retrieves the direct download link and filename for a shareable Google Drive file link.
248
+ Retrieve the direct download link and filename for a shareable Google Drive file link.
237
249
 
238
250
  Args:
239
251
  link (str): The shareable link of the Google Drive file.
240
252
 
241
253
  Returns:
242
- (str): Direct download URL for the Google Drive file.
243
- (str): Original filename of the Google Drive file. If filename extraction fails, returns None.
254
+ url (str): Direct download URL for the Google Drive file.
255
+ filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None.
244
256
 
245
257
  Examples:
246
258
  >>> from ultralytics.utils.downloads import get_google_drive_file_info
@@ -275,16 +287,16 @@ def safe_download(
275
287
  url,
276
288
  file=None,
277
289
  dir=None,
278
- unzip=True,
279
- delete=False,
280
- curl=False,
281
- retry=3,
282
- min_bytes=1e0,
283
- exist_ok=False,
284
- progress=True,
290
+ unzip: bool = True,
291
+ delete: bool = False,
292
+ curl: bool = False,
293
+ retry: int = 3,
294
+ min_bytes: float = 1e0,
295
+ exist_ok: bool = False,
296
+ progress: bool = True,
285
297
  ):
286
298
  """
287
- Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
299
+ Download files from a URL with options for retrying, unzipping, and deleting the downloaded file.
288
300
 
289
301
  Args:
290
302
  url (str): The URL of the file to be downloaded.
@@ -292,14 +304,14 @@ def safe_download(
292
304
  If not provided, the file will be saved with the same name as the URL.
293
305
  dir (str | Path, optional): The directory to save the downloaded file.
294
306
  If not provided, the file will be saved in the current working directory.
295
- unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
296
- delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
297
- curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
298
- retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
307
+ unzip (bool, optional): Whether to unzip the downloaded file.
308
+ delete (bool, optional): Whether to delete the downloaded file after unzipping.
309
+ curl (bool, optional): Whether to use curl command line tool for downloading.
310
+ retry (int, optional): The number of times to retry the download in case of failure.
299
311
  min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
300
- a successful download. Default: 1E0.
301
- exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
302
- progress (bool, optional): Whether to display a progress bar during the download. Default: True.
312
+ a successful download.
313
+ exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.
314
+ progress (bool, optional): Whether to display a progress bar during the download.
303
315
 
304
316
  Returns:
305
317
  (Path | str): The path to the downloaded file or extracted directory.
@@ -376,19 +388,24 @@ def safe_download(
376
388
  return f
377
389
 
378
390
 
379
- def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
391
+ def get_github_assets(
392
+ repo: str = "ultralytics/assets",
393
+ version: str = "latest",
394
+ retry: bool = False,
395
+ ) -> Tuple[str, List[str]]:
380
396
  """
381
- Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
382
- function fetches the latest release assets.
397
+ Retrieve the specified version's tag and assets from a GitHub repository.
398
+
399
+ If the version is not specified, the function fetches the latest release assets.
383
400
 
384
401
  Args:
385
- repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
386
- version (str, optional): The release version to fetch assets from. Defaults to 'latest'.
387
- retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False.
402
+ repo (str, optional): The GitHub repository in the format 'owner/repo'.
403
+ version (str, optional): The release version to fetch assets from.
404
+ retry (bool, optional): Flag to retry the request in case of a failure.
388
405
 
389
406
  Returns:
390
- (str): The release tag.
391
- (List[str]): A list of asset names.
407
+ tag (str): The release tag.
408
+ assets (List[str]): A list of asset names.
392
409
 
393
410
  Examples:
394
411
  >>> tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
@@ -408,14 +425,14 @@ def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
408
425
  return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...]
409
426
 
410
427
 
411
- def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs):
428
+ def attempt_download_asset(file, repo: str = "ultralytics/assets", release: str = "v8.3.0", **kwargs) -> str:
412
429
  """
413
430
  Attempt to download a file from GitHub release assets if it is not found locally.
414
431
 
415
432
  Args:
416
433
  file (str | Path): The filename or file path to be downloaded.
417
- repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
418
- release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'.
434
+ repo (str, optional): The GitHub repository in the format 'owner/repo'.
435
+ release (str, optional): The specific release version to be downloaded.
419
436
  **kwargs (Any): Additional keyword arguments for the download process.
420
437
 
421
438
  Returns:
@@ -459,20 +476,30 @@ def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **
459
476
  return str(file)
460
477
 
461
478
 
462
- def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
479
+ def download(
480
+ url,
481
+ dir=Path.cwd(),
482
+ unzip: bool = True,
483
+ delete: bool = False,
484
+ curl: bool = False,
485
+ threads: int = 1,
486
+ retry: int = 3,
487
+ exist_ok: bool = False,
488
+ ):
463
489
  """
464
- Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
465
- specified.
490
+ Download files from specified URLs to a given directory.
491
+
492
+ Supports concurrent downloads if multiple threads are specified.
466
493
 
467
494
  Args:
468
495
  url (str | List[str]): The URL or list of URLs of the files to be downloaded.
469
- dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory.
470
- unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
471
- delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
472
- curl (bool, optional): Flag to use curl for downloading. Defaults to False.
473
- threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
474
- retry (int, optional): Number of retries in case of download failure. Defaults to 3.
475
- exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
496
+ dir (Path, optional): The directory where the files will be saved.
497
+ unzip (bool, optional): Flag to unzip the files after downloading.
498
+ delete (bool, optional): Flag to delete the zip files after extraction.
499
+ curl (bool, optional): Flag to use curl for downloading.
500
+ threads (int, optional): Number of threads to use for concurrent downloads.
501
+ retry (int, optional): Number of retries in case of download failure.
502
+ exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.
476
503
 
477
504
  Examples:
478
505
  >>> download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True)
@@ -18,13 +18,13 @@ class HUBModelError(Exception):
18
18
 
19
19
  Examples:
20
20
  >>> try:
21
- >>> # Code that might fail to find a model
22
- >>> raise HUBModelError("Custom model not found message")
23
- >>> except HUBModelError as e:
24
- >>> print(e) # Displays the emoji-enhanced error message
21
+ ... # Code that might fail to find a model
22
+ ... raise HUBModelError("Custom model not found message")
23
+ ... except HUBModelError as e:
24
+ ... print(e) # Displays the emoji-enhanced error message
25
25
  """
26
26
 
27
- def __init__(self, message="Model not found. Please check model URL and try again."):
27
+ def __init__(self, message: str = "Model not found. Please check model URL and try again."):
28
28
  """
29
29
  Initialize a HUBModelError exception.
30
30
 
@@ -2,6 +2,7 @@
2
2
 
3
3
  import json
4
4
  from pathlib import Path
5
+ from typing import Dict, List, Optional, Tuple, Union
5
6
 
6
7
  import torch
7
8
 
@@ -9,28 +10,28 @@ from ultralytics.utils import IS_JETSON, LOGGER
9
10
 
10
11
 
11
12
  def export_onnx(
12
- torch_model,
13
- im,
14
- onnx_file,
15
- opset=14,
16
- input_names=["images"],
17
- output_names=["output0"],
18
- dynamic=False,
19
- ):
13
+ torch_model: torch.nn.Module,
14
+ im: torch.Tensor,
15
+ onnx_file: str,
16
+ opset: int = 14,
17
+ input_names: List[str] = ["images"],
18
+ output_names: List[str] = ["output0"],
19
+ dynamic: Union[bool, Dict] = False,
20
+ ) -> None:
20
21
  """
21
- Exports a PyTorch model to ONNX format.
22
+ Export a PyTorch model to ONNX format.
22
23
 
23
24
  Args:
24
25
  torch_model (torch.nn.Module): The PyTorch model to export.
25
26
  im (torch.Tensor): Example input tensor for the model.
26
27
  onnx_file (str): Path to save the exported ONNX file.
27
28
  opset (int): ONNX opset version to use for export.
28
- input_names (list): List of input tensor names.
29
- output_names (list): List of output tensor names.
30
- dynamic (bool | dict, optional): Whether to enable dynamic axes. Defaults to False.
29
+ input_names (List[str]): List of input tensor names.
30
+ output_names (List[str]): List of output tensor names.
31
+ dynamic (bool | Dict, optional): Whether to enable dynamic axes.
31
32
 
32
33
  Notes:
33
- - Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
34
+ Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
34
35
  """
35
36
  torch.onnx.export(
36
37
  torch_model,
@@ -46,44 +47,44 @@ def export_onnx(
46
47
 
47
48
 
48
49
  def export_engine(
49
- onnx_file,
50
- engine_file=None,
51
- workspace=None,
52
- half=False,
53
- int8=False,
54
- dynamic=False,
55
- shape=(1, 3, 640, 640),
56
- dla=None,
50
+ onnx_file: str,
51
+ engine_file: Optional[str] = None,
52
+ workspace: Optional[int] = None,
53
+ half: bool = False,
54
+ int8: bool = False,
55
+ dynamic: bool = False,
56
+ shape: Tuple[int, int, int, int] = (1, 3, 640, 640),
57
+ dla: Optional[int] = None,
57
58
  dataset=None,
58
- metadata=None,
59
- verbose=False,
60
- prefix="",
61
- ):
59
+ metadata: Optional[Dict] = None,
60
+ verbose: bool = False,
61
+ prefix: str = "",
62
+ ) -> None:
62
63
  """
63
- Exports a YOLO model to TensorRT engine format.
64
+ Export a YOLO model to TensorRT engine format.
64
65
 
65
66
  Args:
66
67
  onnx_file (str): Path to the ONNX file to be converted.
67
68
  engine_file (str, optional): Path to save the generated TensorRT engine file.
68
- workspace (int, optional): Workspace size in GB for TensorRT. Defaults to None.
69
- half (bool, optional): Enable FP16 precision. Defaults to False.
70
- int8 (bool, optional): Enable INT8 precision. Defaults to False.
71
- dynamic (bool, optional): Enable dynamic input shapes. Defaults to False.
72
- shape (tuple, optional): Input shape (batch, channels, height, width). Defaults to (1, 3, 640, 640).
73
- dla (int, optional): DLA core to use (Jetson devices only). Defaults to None.
74
- dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration. Defaults to None.
75
- metadata (dict, optional): Metadata to include in the engine file. Defaults to None.
76
- verbose (bool, optional): Enable verbose logging. Defaults to False.
77
- prefix (str, optional): Prefix for log messages. Defaults to "".
69
+ workspace (int, optional): Workspace size in GB for TensorRT.
70
+ half (bool, optional): Enable FP16 precision.
71
+ int8 (bool, optional): Enable INT8 precision.
72
+ dynamic (bool, optional): Enable dynamic input shapes.
73
+ shape (Tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
74
+ dla (int, optional): DLA core to use (Jetson devices only).
75
+ dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
76
+ metadata (Dict, optional): Metadata to include in the engine file.
77
+ verbose (bool, optional): Enable verbose logging.
78
+ prefix (str, optional): Prefix for log messages.
78
79
 
79
80
  Raises:
80
81
  ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
81
82
  RuntimeError: If the ONNX file cannot be parsed.
82
83
 
83
84
  Notes:
84
- - TensorRT version compatibility is handled for workspace size and engine building.
85
- - INT8 calibration requires a dataset and generates a calibration cache.
86
- - Metadata is serialized and written to the engine file if provided.
85
+ TensorRT version compatibility is handled for workspace size and engine building.
86
+ INT8 calibration requires a dataset and generates a calibration cache.
87
+ Metadata is serialized and written to the engine file if provided.
87
88
  """
88
89
  import tensorrt as trt # noqa
89
90
 
@@ -151,12 +152,24 @@ def export_engine(
151
152
 
152
153
  class EngineCalibrator(trt.IInt8Calibrator):
153
154
  """
154
- Custom INT8 calibrator for TensorRT.
155
+ Custom INT8 calibrator for TensorRT engine optimization.
155
156
 
156
- Args:
157
- dataset (object): Dataset for calibration.
157
+ This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
158
+ using a dataset. It handles batch generation, caching, and calibration algorithm selection.
159
+
160
+ Attributes:
161
+ dataset: Dataset for calibration.
162
+ data_iter: Iterator over the calibration dataset.
163
+ algo (trt.CalibrationAlgoType): Calibration algorithm type.
158
164
  batch (int): Batch size for calibration.
159
- cache (str, optional): Path to save the calibration cache. Defaults to "".
165
+ cache (Path): Path to save the calibration cache.
166
+
167
+ Methods:
168
+ get_algorithm: Get the calibration algorithm to use.
169
+ get_batch_size: Get the batch size to use for calibration.
170
+ get_batch: Get the next batch to use for calibration.
171
+ read_calibration_cache: Use existing cache instead of calibrating again.
172
+ write_calibration_cache: Write calibration cache to disk.
160
173
  """
161
174
 
162
175
  def __init__(
@@ -164,6 +177,7 @@ def export_engine(
164
177
  dataset, # ultralytics.data.build.InfiniteDataLoader
165
178
  cache: str = "",
166
179
  ) -> None:
180
+ """Initialize the INT8 calibrator with dataset and cache path."""
167
181
  trt.IInt8Calibrator.__init__(self)
168
182
  self.dataset = dataset
169
183
  self.data_iter = iter(dataset)
@@ -179,22 +193,22 @@ def export_engine(
179
193
  """Get the batch size to use for calibration."""
180
194
  return self.batch or 1
181
195
 
182
- def get_batch(self, names) -> list:
196
+ def get_batch(self, names) -> Optional[List[int]]:
183
197
  """Get the next batch to use for calibration, as a list of device memory pointers."""
184
198
  try:
185
199
  im0s = next(self.data_iter)["img"] / 255.0
186
200
  im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
187
201
  return [int(im0s.data_ptr())]
188
202
  except StopIteration:
189
- # Return [] or None, signal to TensorRT there is no calibration data remaining
203
+ # Return None to signal to TensorRT there is no calibration data remaining
190
204
  return None
191
205
 
192
- def read_calibration_cache(self) -> bytes:
206
+ def read_calibration_cache(self) -> Optional[bytes]:
193
207
  """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
194
208
  if self.cache.exists() and self.cache.suffix == ".cache":
195
209
  return self.cache.read_bytes()
196
210
 
197
- def write_calibration_cache(self, cache) -> None:
211
+ def write_calibration_cache(self, cache: bytes) -> None:
198
212
  """Write calibration cache to disk."""
199
213
  _ = self.cache.write_bytes(cache)
200
214
 
@@ -8,6 +8,7 @@ import tempfile
8
8
  from contextlib import contextmanager
9
9
  from datetime import datetime
10
10
  from pathlib import Path
11
+ from typing import Union
11
12
 
12
13
 
13
14
  class WorkingDirectory(contextlib.ContextDecorator):
@@ -38,22 +39,22 @@ class WorkingDirectory(contextlib.ContextDecorator):
38
39
  >>> pass
39
40
  """
40
41
 
41
- def __init__(self, new_dir):
42
- """Sets the working directory to 'new_dir' upon instantiation for use with context managers or decorators."""
42
+ def __init__(self, new_dir: Union[str, Path]):
43
+ """Initialize the WorkingDirectory context manager with the target directory."""
43
44
  self.dir = new_dir # new dir
44
45
  self.cwd = Path.cwd().resolve() # current dir
45
46
 
46
47
  def __enter__(self):
47
- """Changes the current working directory to the specified directory upon entering the context."""
48
+ """Change the current working directory to the specified directory upon entering the context."""
48
49
  os.chdir(self.dir)
49
50
 
50
51
  def __exit__(self, exc_type, exc_val, exc_tb): # noqa
51
- """Restores the original working directory when exiting the context."""
52
+ """Restore the original working directory when exiting the context."""
52
53
  os.chdir(self.cwd)
53
54
 
54
55
 
55
56
  @contextmanager
56
- def spaces_in_path(path):
57
+ def spaces_in_path(path: Union[str, Path]):
57
58
  """
58
59
  Context manager to handle paths with spaces in their names.
59
60
 
@@ -64,7 +65,8 @@ def spaces_in_path(path):
64
65
  path (str | Path): The original path that may contain spaces.
65
66
 
66
67
  Yields:
67
- (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
68
+ (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the
69
+ original path.
68
70
 
69
71
  Examples:
70
72
  >>> with spaces_in_path('/path/with spaces') as new_path:
@@ -82,7 +84,6 @@ def spaces_in_path(path):
82
84
 
83
85
  # Copy file/directory
84
86
  if path.is_dir():
85
- # tmp_path.mkdir(parents=True, exist_ok=True)
86
87
  shutil.copytree(path, tmp_path)
87
88
  elif path.is_file():
88
89
  tmp_path.parent.mkdir(parents=True, exist_ok=True)
@@ -104,7 +105,7 @@ def spaces_in_path(path):
104
105
  yield path
105
106
 
106
107
 
107
- def increment_path(path, exist_ok=False, sep="", mkdir=False):
108
+ def increment_path(path: Union[str, Path], exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path:
108
109
  """
109
110
  Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
110
111
 
@@ -114,9 +115,9 @@ def increment_path(path, exist_ok=False, sep="", mkdir=False):
114
115
 
115
116
  Args:
116
117
  path (str | Path): Path to increment.
117
- exist_ok (bool): If True, the path will not be incremented and returned as-is.
118
- sep (str): Separator to use between the path and the incrementation number.
119
- mkdir (bool): Create a directory if it does not exist.
118
+ exist_ok (bool, optional): If True, the path will not be incremented and returned as-is.
119
+ sep (str, optional): Separator to use between the path and the incrementation number.
120
+ mkdir (bool, optional): Create a directory if it does not exist.
120
121
 
121
122
  Returns:
122
123
  (Path): Incremented path.
@@ -152,20 +153,20 @@ def increment_path(path, exist_ok=False, sep="", mkdir=False):
152
153
  return path
153
154
 
154
155
 
155
- def file_age(path=__file__):
156
+ def file_age(path: Union[str, Path] = __file__) -> int:
156
157
  """Return days since the last modification of the specified file."""
157
158
  dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
158
159
  return dt.days # + dt.seconds / 86400 # fractional days
159
160
 
160
161
 
161
- def file_date(path=__file__):
162
- """Returns the file modification date in 'YYYY-M-D' format."""
162
+ def file_date(path: Union[str, Path] = __file__) -> str:
163
+ """Return the file modification date in 'YYYY-M-D' format."""
163
164
  t = datetime.fromtimestamp(Path(path).stat().st_mtime)
164
165
  return f"{t.year}-{t.month}-{t.day}"
165
166
 
166
167
 
167
- def file_size(path):
168
- """Returns the size of a file or directory in megabytes (MB)."""
168
+ def file_size(path: Union[str, Path]) -> float:
169
+ """Return the size of a file or directory in megabytes (MB)."""
169
170
  if isinstance(path, (str, Path)):
170
171
  mb = 1 << 20 # bytes to MiB (1024 ** 2)
171
172
  path = Path(path)
@@ -176,20 +177,20 @@ def file_size(path):
176
177
  return 0.0
177
178
 
178
179
 
179
- def get_latest_run(search_dir="."):
180
- """Returns the path to the most recent 'last.pt' file in the specified directory for resuming training."""
180
+ def get_latest_run(search_dir: str = ".") -> str:
181
+ """Return the path to the most recent 'last.pt' file in the specified directory for resuming training."""
181
182
  last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
182
183
  return max(last_list, key=os.path.getctime) if last_list else ""
183
184
 
184
185
 
185
- def update_models(model_names=("yolo11n.pt",), source_dir=Path("."), update_names=False):
186
+ def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path("."), update_names: bool = False):
186
187
  """
187
188
  Update and re-save specified YOLO models in an 'updated_models' subdirectory.
188
189
 
189
190
  Args:
190
- model_names (Tuple[str, ...]): Model filenames to update.
191
- source_dir (Path): Directory containing models and target subdirectory.
192
- update_names (bool): Update model names from a data YAML.
191
+ model_names (tuple, optional): Model filenames to update.
192
+ source_dir (Path, optional): Directory containing models and target subdirectory.
193
+ update_names (bool, optional): Update model names from a data YAML.
193
194
 
194
195
  Examples:
195
196
  Update specified YOLO models and save them in 'updated_models' subdirectory: