dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243)
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/downloads.py
@@ -10,7 +10,7 @@ from multiprocessing.pool import ThreadPool
  from pathlib import Path
  from urllib import parse, request
 
- from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
+ from ultralytics.utils import ASSETS_URL, LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
 
  # Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
  GITHUB_ASSETS_REPO = "ultralytics/assets"
@@ -43,8 +43,7 @@ GITHUB_ASSETS_STEMS = frozenset(k.rpartition(".")[0] for k in GITHUB_ASSETS_NAME
 
 
  def is_url(url: str | Path, check: bool = False) -> bool:
-     """
-     Validate if the given string is a URL and optionally check if the URL exists online.
+     """Validate if the given string is a URL and optionally check if the URL exists online.
 
      Args:
          url (str): The string to be validated as a URL.
@@ -60,18 +59,18 @@ def is_url(url: str | Path, check: bool = False) -> bool:
      try:
          url = str(url)
          result = parse.urlparse(url)
-         assert all([result.scheme, result.netloc])  # check if is url
+         if not (result.scheme and result.netloc):
+             return False
          if check:
-             with request.urlopen(url) as response:
-                 return response.getcode() == 200  # check if exists online
+             r = request.urlopen(request.Request(url, method="HEAD"), timeout=3)
+             return 200 <= r.getcode() < 400
          return True
      except Exception:
          return False
 
 
  def delete_dsstore(path: str | Path, files_to_delete: tuple[str, ...] = (".DS_Store", "__MACOSX")) -> None:
-     """
-     Delete all specified system files in a directory.
+     """Delete all specified system files in a directory.
 
      Args:
          path (str | Path): The directory path where the files should be deleted.
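The rewritten is_url above replaces the bare assert and full GET probe with an explicit scheme/netloc check and a lightweight HEAD request that times out after 3 seconds and accepts any 2xx/3xx status. A minimal usage sketch, assuming only the signature and behavior visible in the hunk:

    from ultralytics.utils.downloads import is_url

    is_url("https://ultralytics.com/assets/coco8.zip")   # True, syntactic check only
    is_url("not-a-url")                                   # False, no scheme or netloc
    is_url("https://ultralytics.com", check=True)         # True only if a HEAD request returns 2xx/3xx within 3 s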
@@ -82,7 +81,7 @@ def delete_dsstore(path: str | Path, files_to_delete: tuple[str, ...] = (".DS_St
          >>> delete_dsstore("path/to/dir")
 
      Notes:
-         ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
+         ".DS_Store" files are created by the Apple operating system and contain metadata about folders and files. They
          are hidden system files and can cause issues when transferring files between different operating systems.
      """
      for file in files_to_delete:
@@ -98,8 +97,7 @@ def zip_directory(
      exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
      progress: bool = True,
  ) -> Path:
-     """
-     Zip the contents of a directory, excluding specified files.
+     """Zip the contents of a directory, excluding specified files.
 
      The resulting zip file is named after the directory and placed alongside it.
 
@@ -141,12 +139,11 @@ def unzip_file(
      exist_ok: bool = False,
      progress: bool = True,
  ) -> Path:
-     """
-     Unzip a *.zip file to the specified path, excluding specified files.
+     """Unzip a *.zip file to the specified path, excluding specified files.
 
-     If the zipfile does not contain a single top-level directory, the function will create a new
-     directory with the same name as the zipfile (without the extension) to extract its contents.
-     If a path is not provided, the function will use the parent directory of the zipfile as the default path.
+     If the zipfile does not contain a single top-level directory, the function will create a new directory with the same
+     name as the zipfile (without the extension) to extract its contents. If a path is not provided, the function will
+     use the parent directory of the zipfile as the default path.
 
      Args:
          file (str | Path): The path to the zipfile to be extracted.
@@ -182,7 +179,7 @@ def unzip_file(
      if unzip_as_dir:
          # Zip has 1 top-level directory
          extract_path = path  # i.e. ../datasets
-         path = Path(path) / list(top_level_dirs)[0]  # i.e. extract coco8/ dir to ../datasets/
+         path = Path(path) / next(iter(top_level_dirs))  # i.e. extract coco8/ dir to ../datasets/
      else:
          # Zip has multiple files at top level
          path = extract_path = Path(path) / Path(file).stem  # i.e. extract multiple files to ../datasets/coco8/
@@ -209,8 +206,7 @@ def check_disk_space(
      sf: float = 1.5,
      hard: bool = True,
  ) -> bool:
-     """
-     Check if there is sufficient disk space to download and store a file.
+     """Check if there is sufficient disk space to download and store a file.
 
      Args:
          file_bytes (int): The file size in bytes.
@@ -221,7 +217,7 @@ def check_disk_space(
      Returns:
          (bool): True if there is sufficient disk space, False otherwise.
      """
-     total, used, free = shutil.disk_usage(path)  # bytes
+     _total, _used, free = shutil.disk_usage(path)  # bytes
      if file_bytes * sf < free:
          return True  # sufficient space
 
@@ -237,8 +233,7 @@
 
 
  def get_google_drive_file_info(link: str) -> tuple[str, str | None]:
-     """
-     Retrieve the direct download link and filename for a shareable Google Drive file link.
+     """Retrieve the direct download link and filename for a shareable Google Drive file link.
 
      Args:
          link (str): The shareable link of the Google Drive file.
@@ -288,16 +283,15 @@ def safe_download(
      exist_ok: bool = False,
      progress: bool = True,
  ) -> Path | str:
-     """
-     Download files from a URL with options for retrying, unzipping, and deleting the downloaded file. Enhanced with
+     """Download files from a URL with options for retrying, unzipping, and deleting the downloaded file. Enhanced with
      robust partial download detection using Content-Length validation.
 
      Args:
          url (str): The URL of the file to be downloaded.
-         file (str, optional): The filename of the downloaded file.
-             If not provided, the file will be saved with the same name as the URL.
-         dir (str | Path, optional): The directory to save the downloaded file.
-             If not provided, the file will be saved in the current working directory.
+         file (str, optional): The filename of the downloaded file. If not provided, the file will be saved with the same
+             name as the URL.
+         dir (str | Path, optional): The directory to save the downloaded file. If not provided, the file will be saved
+             in the current working directory.
          unzip (bool, optional): Whether to unzip the downloaded file.
          delete (bool, optional): Whether to delete the downloaded file after unzipping.
          curl (bool, optional): Whether to use curl command line tool for downloading.
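For orientation, the consolidated safe_download docstring above still covers the same parameters; a typical call that exercises the documented url, dir, and unzip arguments might look like this sketch (the asset URL is illustrative):

    from ultralytics.utils.downloads import safe_download

    # Downloads with retry and Content-Length validation, then unzips the archive in place
    safe_download(
        url="https://github.com/ultralytics/assets/releases/download/v8.3.0/coco8.zip",
        dir="datasets",
        unzip=True,
    )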
@@ -323,10 +317,7 @@
      if "://" not in str(url) and Path(url).is_file():  # URL exists ('://' check required in Windows Python<3.10)
          f = Path(url)  # filename
      elif not f.is_file():  # URL and file do not exist
-         uri = (url if gdrive else clean_url(url)).replace(  # cleaned and aliased url
-             "https://github.com/ultralytics/assets/releases/download/v0.0.0/",
-             "https://ultralytics.com/assets/",  # assets alias
-         )
+         uri = (url if gdrive else clean_url(url)).replace(ASSETS_URL, "https://ultralytics.com/assets")  # clean
          desc = f"Downloading {uri} to '{f}'"
          f.parent.mkdir(parents=True, exist_ok=True)  # make directory if missing
          curl_installed = shutil.which("curl")
@@ -374,10 +365,10 @@
                  raise  # Re-raise immediately - no point retrying if insufficient disk space
              except Exception as e:
                  if i == 0 and not is_online():
-                     raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment is not online.")) from e
+                     raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment may be offline.")) from e
                  elif i >= retry:
-                     raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e
-                 LOGGER.warning(f"Download failure, retrying {i + 1}/{retry} {uri}...")
+                     raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached. {e}")) from e
+                 LOGGER.warning(f"Download failure, retrying {i + 1}/{retry} {uri}... {e}")
 
      if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
          from zipfile import is_zipfile
@@ -399,8 +390,7 @@ def get_github_assets(
      version: str = "latest",
      retry: bool = False,
  ) -> tuple[str, list[str]]:
-     """
-     Retrieve the specified version's tag and assets from a GitHub repository.
+     """Retrieve the specified version's tag and assets from a GitHub repository.
 
      If the version is not specified, the function fetches the latest release assets.
 
@@ -437,8 +427,7 @@ def attempt_download_asset(
      release: str = "v8.3.0",
      **kwargs,
  ) -> str:
-     """
-     Attempt to download a file from GitHub release assets if it is not found locally.
+     """Attempt to download a file from GitHub release assets if it is not found locally.
 
      Args:
          file (str | Path): The filename or file path to be downloaded.
@@ -497,8 +486,7 @@ def download(
      retry: int = 3,
      exist_ok: bool = False,
  ) -> None:
-     """
-     Download files from specified URLs to a given directory.
+     """Download files from specified URLs to a given directory.
 
      Supports concurrent downloads if multiple threads are specified.
 
ultralytics/utils/errors.py
@@ -4,11 +4,10 @@ from ultralytics.utils import emojis
 
 
  class HUBModelError(Exception):
-     """
-     Exception raised when a model cannot be found or retrieved from Ultralytics HUB.
+     """Exception raised when a model cannot be found or retrieved from Ultralytics HUB.
 
-     This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.
-     The error message is processed to include emojis for better user experience.
+     This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO. The
+     error message is processed to include emojis for better user experience.
 
      Attributes:
          message (str): The error message displayed when the exception is raised.
@@ -25,19 +24,12 @@ class HUBModelError(Exception):
      """
 
      def __init__(self, message: str = "Model not found. Please check model URL and try again."):
-         """
-         Initialize a HUBModelError exception.
+         """Initialize a HUBModelError exception.
 
-         This exception is raised when a requested model is not found or cannot be retrieved from Ultralytics HUB.
-         The message is processed to include emojis for better user experience.
+         This exception is raised when a requested model is not found or cannot be retrieved from Ultralytics HUB. The
+         message is processed to include emojis for better user experience.
 
          Args:
              message (str, optional): The error message to display when the exception is raised.
-
-         Examples:
-             >>> try:
-             ...     raise HUBModelError("Custom model error message")
-             ... except HUBModelError as e:
-             ...     print(e)
          """
          super().__init__(emojis(message))
ultralytics/utils/events.py
@@ -24,8 +24,7 @@ def _post(url: str, data: dict, timeout: float = 5.0) -> None:
 
 
  class Events:
-     """
-     Collect and send anonymous usage analytics with rate-limiting.
+     """Collect and send anonymous usage analytics with rate-limiting.
 
      Event collection and transmission are enabled when sync is enabled in settings, the current process is rank -1 or 0,
      tests are not running, the environment is online, and the installation source is either pip or the official
@@ -71,8 +70,7 @@ class Events:
      )
 
      def __call__(self, cfg, device=None) -> None:
-         """
-         Queue an event and flush the queue asynchronously when the rate limit elapses.
+         """Queue an event and flush the queue asynchronously when the rate limit elapses.
 
          Args:
              cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
ultralytics/utils/export/__init__.py (new file)
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from .engine import onnx2engine, torch2onnx
+ from .imx import torch2imx
+ from .tensorflow import keras2pb, onnx2saved_model, pb2tfjs, tflite2edgetpu
+
+ __all__ = ["keras2pb", "onnx2engine", "onnx2saved_model", "pb2tfjs", "tflite2edgetpu", "torch2imx", "torch2onnx"]
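With the new ultralytics/utils/export package, the former module-level helpers are renamed and re-exported, so downstream imports change roughly as follows (old names are taken from the removed lines in the export/engine.py hunks below; nothing else is assumed):

    # 8.3.196
    from ultralytics.utils.export import export_engine, export_onnx

    # 8.3.248
    from ultralytics.utils.export import onnx2engine, torch2onnx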
ultralytics/utils/export.py → ultralytics/utils/export/engine.py (renamed)
@@ -8,9 +8,10 @@ from pathlib import Path
  import torch
 
  from ultralytics.utils import IS_JETSON, LOGGER
+ from ultralytics.utils.torch_utils import TORCH_2_4
 
 
- def export_onnx(
+ def torch2onnx(
      torch_model: torch.nn.Module,
      im: torch.Tensor,
      onnx_file: str,
@@ -19,8 +20,7 @@
      output_names: list[str] = ["output0"],
      dynamic: bool | dict = False,
  ) -> None:
-     """
-     Export a PyTorch model to ONNX format.
+     """Export a PyTorch model to ONNX format.
 
      Args:
          torch_model (torch.nn.Module): The PyTorch model to export.
@@ -34,6 +34,7 @@
      Notes:
          Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
      """
+     kwargs = {"dynamo": False} if TORCH_2_4 else {}
      torch.onnx.export(
          torch_model,
          im,
@@ -44,10 +45,11 @@
          input_names=input_names,
          output_names=output_names,
          dynamic_axes=dynamic or None,
+         **kwargs,
      )
 
 
- def export_engine(
+ def onnx2engine(
      onnx_file: str,
      engine_file: str | None = None,
      workspace: int | None = None,
@@ -61,8 +63,7 @@
      verbose: bool = False,
      prefix: str = "",
  ) -> None:
-     """
-     Export a YOLO model to TensorRT engine format.
+     """Export a YOLO model to TensorRT engine format.
 
      Args:
          onnx_file (str): Path to the ONNX file to be converted.
@@ -87,7 +88,7 @@
          INT8 calibration requires a dataset and generates a calibration cache.
          Metadata is serialized and written to the engine file if provided.
      """
-     import tensorrt as trt  # noqa
+     import tensorrt as trt
 
      engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
 
@@ -98,12 +99,12 @@
      # Engine builder
      builder = trt.Builder(logger)
      config = builder.create_builder_config()
-     workspace = int((workspace or 0) * (1 << 30))
+     workspace_bytes = int((workspace or 0) * (1 << 30))
      is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10  # is TensorRT >= 10
-     if is_trt10 and workspace > 0:
-         config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
-     elif workspace > 0:  # TensorRT versions 7, 8
-         config.max_workspace_size = workspace
+     if is_trt10 and workspace_bytes > 0:
+         config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
+     elif workspace_bytes > 0:  # TensorRT versions 7, 8
+         config.max_workspace_size = workspace_bytes
      flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
      network = builder.create_network(flag)
      half = builder.platform_has_fast_fp16 and half
@@ -151,11 +152,10 @@
          config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
 
          class EngineCalibrator(trt.IInt8Calibrator):
-             """
-             Custom INT8 calibrator for TensorRT engine optimization.
+             """Custom INT8 calibrator for TensorRT engine optimization.
 
-             This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
-             using a dataset. It handles batch generation, caching, and calibration algorithm selection.
+             This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration using
+             a dataset. It handles batch generation, caching, and calibration algorithm selection.
 
              Attributes:
                  dataset: Dataset for calibration.
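Taken together, the renamed helpers keep their positional interfaces, so a two-step ONNX-then-TensorRT conversion might look like the sketch below (a plain torch.nn.Module and dummy input are assumed; the half keyword is inferred from its use in the function body above):

    import torch

    from ultralytics import YOLO
    from ultralytics.utils.export import onnx2engine, torch2onnx

    model = YOLO("yolo11n.pt").model.eval()   # underlying torch.nn.Module
    im = torch.zeros(1, 3, 640, 640)          # dummy input at the export image size

    torch2onnx(model, im, "yolo11n.onnx")     # torch_model, im, onnx_file as in the signature above
    onnx2engine("yolo11n.onnx", half=True)    # engine_file defaults to yolo11n.engine per the hunk above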
ultralytics/utils/export/imx.py (new file)
@@ -0,0 +1,325 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ import subprocess
+ import sys
+ import types
+ from pathlib import Path
+ from shutil import which
+
+ import numpy as np
+ import torch
+
+ from ultralytics.nn.modules import Detect, Pose, Segment
+ from ultralytics.utils import LOGGER, WINDOWS
+ from ultralytics.utils.patches import onnx_export_patch
+ from ultralytics.utils.tal import make_anchors
+ from ultralytics.utils.torch_utils import copy_attr
+
+ # Configuration for Model Compression Toolkit (MCT) quantization
+ MCT_CONFIG = {
+     "YOLO11": {
+         "detect": {
+             "layer_names": ["sub", "mul_2", "add_14", "cat_21"],
+             "weights_memory": 2585350.2439,
+             "n_layers": 238,
+         },
+         "pose": {
+             "layer_names": ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"],
+             "weights_memory": 2437771.67,
+             "n_layers": 257,
+         },
+         "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": 112},
+         "segment": {"layer_names": ["sub", "mul_2", "add_14", "cat_22"], "weights_memory": 2466604.8, "n_layers": 265},
+     },
+     "YOLOv8": {
+         "detect": {"layer_names": ["sub", "mul", "add_6", "cat_17"], "weights_memory": 2550540.8, "n_layers": 168},
+         "pose": {
+             "layer_names": ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"],
+             "weights_memory": 2482451.85,
+             "n_layers": 187,
+         },
+         "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": 73},
+         "segment": {"layer_names": ["sub", "mul", "add_6", "cat_18"], "weights_memory": 2580060.0, "n_layers": 195},
+     },
+ }
+
+
+ class FXModel(torch.nn.Module):
+     """A custom model class for torch.fx compatibility.
+
+     This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
+     manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
+     copying.
+
+     Attributes:
+         model (nn.Module): The original model's layers.
+     """
+
+     def __init__(self, model, imgsz=(640, 640)):
+         """Initialize the FXModel.
+
+         Args:
+             model (nn.Module): The original model to wrap for torch.fx compatibility.
+             imgsz (tuple[int, int]): The input image size (height, width). Default is (640, 640).
+         """
+         super().__init__()
+         copy_attr(self, model)
+         # Explicitly set `model` since `copy_attr` somehow does not copy it.
+         self.model = model.model
+         self.imgsz = imgsz
+
+     def forward(self, x):
+         """Forward pass through the model.
+
+         This method performs the forward pass through the model, handling the dependencies between layers and saving
+         intermediate outputs.
+
+         Args:
+             x (torch.Tensor): The input tensor to the model.
+
+         Returns:
+             (torch.Tensor): The output tensor from the model.
+         """
+         y = []  # outputs
+         for m in self.model:
+             if m.f != -1:  # if not from previous layer
+                 # from earlier layers
+                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
+             if isinstance(m, Detect):
+                 m._inference = types.MethodType(_inference, m)  # bind method to Detect
+                 m.anchors, m.strides = (
+                     x.transpose(0, 1)
+                     for x in make_anchors(
+                         torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
+                     )
+                 )
+                 if type(m) is Pose:
+                     m.forward = types.MethodType(pose_forward, m)  # bind method to Detect
+                 if type(m) is Segment:
+                     m.forward = types.MethodType(segment_forward, m)  # bind method to Detect
+             x = m(x)  # run
+             y.append(x)  # save output
+         return x
+
+
+ def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
+     """Decode boxes and cls scores for imx object detection."""
+     x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
+     box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
+     dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
+     return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
+
+
+ def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """Forward pass for imx pose estimation, including keypoint decoding."""
+     bs = x[0].shape[0]  # batch size
+     kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
+     x = Detect.forward(self, x)
+     pred_kpt = self.kpts_decode(bs, kpt)
+     return *x, pred_kpt.permute(0, 2, 1)
+
+
+ def segment_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     """Forward pass for imx segmentation."""
+     p = self.proto(x[0])  # mask protos
+     bs = p.shape[0]  # batch size
+     mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
+     x = Detect.forward(self, x)
+     return *x, mc.transpose(1, 2), p
+
+
+ class NMSWrapper(torch.nn.Module):
+     """Wrap PyTorch Module with multiclass_nms layer from edge-mdt-cl."""
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         score_threshold: float = 0.001,
+         iou_threshold: float = 0.7,
+         max_detections: int = 300,
+         task: str = "detect",
+     ):
+         """Initialize NMSWrapper with PyTorch Module and NMS parameters.
+
+         Args:
+             model (torch.nn.Module): Model instance.
+             score_threshold (float): Score threshold for non-maximum suppression.
+             iou_threshold (float): Intersection over union threshold for non-maximum suppression.
+             max_detections (int): The number of detections to return.
+             task (str): Task type, either 'detect' or 'pose'.
+         """
+         super().__init__()
+         self.model = model
+         self.score_threshold = score_threshold
+         self.iou_threshold = iou_threshold
+         self.max_detections = max_detections
+         self.task = task
+
+     def forward(self, images):
+         """Forward pass with model inference and NMS post-processing."""
+         from edgemdt_cl.pytorch.nms.nms_with_indices import multiclass_nms_with_indices
+
+         # model inference
+         outputs = self.model(images)
+         boxes, scores = outputs[0], outputs[1]
+         nms_outputs = multiclass_nms_with_indices(
+             boxes=boxes,
+             scores=scores,
+             score_threshold=self.score_threshold,
+             iou_threshold=self.iou_threshold,
+             max_detections=self.max_detections,
+         )
+         if self.task == "pose":
+             kpts = outputs[2]  # (bs, max_detections, kpts 17*3)
+             out_kpts = torch.gather(kpts, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, kpts.size(-1)))
+             return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_kpts
+         if self.task == "segment":
+             mc, proto = outputs[2], outputs[3]
+             out_mc = torch.gather(mc, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, mc.size(-1)))
+             return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_mc, proto
+         return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, nms_outputs.n_valid
+
+
+ def torch2imx(
+     model: torch.nn.Module,
+     file: Path | str,
+     conf: float,
+     iou: float,
+     max_det: int,
+     metadata: dict | None = None,
+     gptq: bool = False,
+     dataset=None,
+     prefix: str = "",
+ ):
+     """Export YOLO model to IMX format for deployment on Sony IMX500 devices.
+
+     This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it to IMX format compatible
+     with Sony IMX500 edge devices. It supports both YOLOv8n and YOLO11n models for detection and pose estimation tasks.
+
+     Args:
+         model (torch.nn.Module): The YOLO model to export. Must be YOLOv8n or YOLO11n.
+         file (Path | str): Output file path for the exported model.
+         conf (float): Confidence threshold for NMS post-processing.
+         iou (float): IoU threshold for NMS post-processing.
+         max_det (int): Maximum number of detections to return.
+         metadata (dict | None, optional): Metadata to embed in the ONNX model. Defaults to None.
+         gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization. If False, uses standard Post
+             Training Quantization. Defaults to False.
+         dataset (optional): Representative dataset for quantization calibration. Defaults to None.
+         prefix (str, optional): Logging prefix string. Defaults to "".
+
+     Returns:
+         f (Path): Path to the exported IMX model directory
+
+     Raises:
+         ValueError: If the model is not a supported YOLOv8n or YOLO11n variant.
+
+     Examples:
+         >>> from ultralytics import YOLO
+         >>> model = YOLO("yolo11n.pt")
+         >>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.45, max_det=300)
+
+     Notes:
+         - Requires model_compression_toolkit, onnx, edgemdt_tpc, and edge-mdt-cl packages
+         - Only supports YOLOv8n and YOLO11n models (detection and pose tasks)
+         - Output includes quantized ONNX model, IMX binary, and labels.txt file
+     """
+     import model_compression_toolkit as mct
+     import onnx
+     from edgemdt_tpc import get_target_platform_capabilities
+
+     LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
+
+     def representative_dataset_gen(dataloader=dataset):
+         for batch in dataloader:
+             img = batch["img"]
+             img = img / 255.0
+             yield [img]
+
+     # NOTE: need tpc_version to be "4.0" for IMX500 Pose estimation models
+     tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")
+
+     bit_cfg = mct.core.BitWidthConfig()
+     mct_config = MCT_CONFIG["YOLO11" if "C2PSA" in model.__str__() else "YOLOv8"][model.task]
+
+     # Check if the model has the expected number of layers
+     if len(list(model.modules())) != mct_config["n_layers"]:
+         raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")
+
+     for layer_name in mct_config["layer_names"]:
+         bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)
+
+     config = mct.core.CoreConfig(
+         mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
+         quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
+         bit_width_config=bit_cfg,
+     )
+
+     resource_utilization = mct.core.ResourceUtilization(weights_memory=mct_config["weights_memory"])
+
+     quant_model = (
+         mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
+             model=model,
+             representative_data_gen=representative_dataset_gen,
+             target_resource_utilization=resource_utilization,
+             gptq_config=mct.gptq.get_pytorch_gptq_config(
+                 n_epochs=1000, use_hessian_based_weights=False, use_hessian_sample_attention=False
+             ),
+             core_config=config,
+             target_platform_capabilities=tpc,
+         )[0]
+         if gptq
+         else mct.ptq.pytorch_post_training_quantization(  # Perform post training quantization
+             in_module=model,
+             representative_data_gen=representative_dataset_gen,
+             target_resource_utilization=resource_utilization,
+             core_config=config,
+             target_platform_capabilities=tpc,
+         )[0]
+     )
+
+     if model.task != "classify":
+         quant_model = NMSWrapper(
+             model=quant_model,
+             score_threshold=conf or 0.001,
+             iou_threshold=iou,
+             max_detections=max_det,
+             task=model.task,
+         )
+
+     f = Path(str(file).replace(file.suffix, "_imx_model"))
+     f.mkdir(exist_ok=True)
+     onnx_model = f / Path(str(file.name).replace(file.suffix, "_imx.onnx"))  # js dir
+
+     with onnx_export_patch():
+         mct.exporter.pytorch_export_model(
+             model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
+         )
+
+     model_onnx = onnx.load(onnx_model)  # load onnx model
+     for k, v in metadata.items():
+         meta = model_onnx.metadata_props.add()
+         meta.key, meta.value = k, str(v)
+
+     onnx.save(model_onnx, onnx_model)
+
+     # Find imxconv-pt binary - check venv bin directory first, then PATH
+     bin_dir = Path(sys.executable).parent
+     imxconv = bin_dir / ("imxconv-pt.exe" if WINDOWS else "imxconv-pt")
+     if not imxconv.exists():
+         imxconv = which("imxconv-pt")  # fallback to PATH
+     if not imxconv:
+         raise FileNotFoundError("imxconv-pt not found. Install with: pip install imx500-converter[pt]")
+
+     subprocess.run(
+         [str(imxconv), "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
+         check=True,
+     )
+
+     # Needed for imx models.
+     with open(f / "labels.txt", "w", encoding="utf-8") as file:
+         file.writelines([f"{name}\n" for _, name in model.names.items()])
+
+     return f
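In practice this new module is not called directly; the exporter builds the calibration dataloader, wraps the network (e.g. via FXModel above), and invokes torch2imx, so the user-facing path stays a single export call. A hedged sketch, assuming the standard Ultralytics export API and an installed imx500-converter toolchain:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")                     # IMX export is limited to YOLOv8n/YOLO11n per the layer-count check above
    model.export(format="imx", data="coco8.yaml")  # produces a *_imx_model directory with the ONNX, IMX binary, and labels.txt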