dgenerate_ultralytics_headless-8.3.253-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
ultralytics/engine/exporter.py
@@ -0,0 +1,1580 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+ """
+ Export a YOLO PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit.
+
+ Format                  | `format=argument`  | Model
+ ---                     | ---                | ---
+ PyTorch                 | -                  | yolo11n.pt
+ TorchScript             | `torchscript`      | yolo11n.torchscript
+ ONNX                    | `onnx`             | yolo11n.onnx
+ OpenVINO                | `openvino`         | yolo11n_openvino_model/
+ TensorRT                | `engine`           | yolo11n.engine
+ CoreML                  | `coreml`           | yolo11n.mlpackage
+ TensorFlow SavedModel   | `saved_model`      | yolo11n_saved_model/
+ TensorFlow GraphDef     | `pb`               | yolo11n.pb
+ TensorFlow Lite         | `tflite`           | yolo11n.tflite
+ TensorFlow Edge TPU     | `edgetpu`          | yolo11n_edgetpu.tflite
+ TensorFlow.js           | `tfjs`             | yolo11n_web_model/
+ PaddlePaddle            | `paddle`           | yolo11n_paddle_model/
+ MNN                     | `mnn`              | yolo11n.mnn
+ NCNN                    | `ncnn`             | yolo11n_ncnn_model/
+ IMX                     | `imx`              | yolo11n_imx_model/
+ RKNN                    | `rknn`             | yolo11n_rknn_model/
+ ExecuTorch              | `executorch`       | yolo11n_executorch_model/
+ Axelera                 | `axelera`          | yolo11n_axelera_model/
+
+ Requirements:
+     $ pip install "ultralytics[export]"
+
+ Python:
+     from ultralytics import YOLO
+     model = YOLO('yolo11n.pt')
+     results = model.export(format='onnx')
+
+ CLI:
+     $ yolo mode=export model=yolo11n.pt format=onnx
+
+ Inference:
+     $ yolo predict model=yolo11n.pt                 # PyTorch
+                          yolo11n.torchscript        # TorchScript
+                          yolo11n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
+                          yolo11n_openvino_model     # OpenVINO
+                          yolo11n.engine             # TensorRT
+                          yolo11n.mlpackage          # CoreML (macOS-only)
+                          yolo11n_saved_model        # TensorFlow SavedModel
+                          yolo11n.pb                 # TensorFlow GraphDef
+                          yolo11n.tflite             # TensorFlow Lite
+                          yolo11n_edgetpu.tflite     # TensorFlow Edge TPU
+                          yolo11n_paddle_model       # PaddlePaddle
+                          yolo11n.mnn                # MNN
+                          yolo11n_ncnn_model         # NCNN
+                          yolo11n_imx_model          # IMX
+                          yolo11n_rknn_model         # RKNN
+                          yolo11n_executorch_model   # ExecuTorch
+                          yolo11n_axelera_model      # Axelera
+
+ TensorFlow.js:
+     $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
+     $ npm install
+     $ ln -s ../../yolo11n_web_model public/yolo11n_web_model
+     $ npm start
+ """
+
+ import json
+ import os
+ import re
+ import shutil
+ import subprocess
+ import time
+ from copy import deepcopy
+ from datetime import datetime
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+
+ from ultralytics import __version__
+ from ultralytics.cfg import TASK2DATA, get_cfg
+ from ultralytics.data import build_dataloader
+ from ultralytics.data.dataset import YOLODataset
+ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+ from ultralytics.nn.autobackend import check_class_names, default_class_names
+ from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
+ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, WorldModel
+ from ultralytics.utils import (
+     ARM64,
+     DEFAULT_CFG,
+     IS_COLAB,
+     IS_DEBIAN_BOOKWORM,
+     IS_DEBIAN_TRIXIE,
+     IS_DOCKER,
+     IS_JETSON,
+     IS_RASPBERRYPI,
+     IS_UBUNTU,
+     LINUX,
+     LOGGER,
+     MACOS,
+     MACOS_VERSION,
+     RKNN_CHIPS,
+     SETTINGS,
+     TORCH_VERSION,
+     WINDOWS,
+     YAML,
+     callbacks,
+     colorstr,
+     get_default_args,
+ )
+ from ultralytics.utils.checks import (
+     IS_PYTHON_3_10,
+     IS_PYTHON_MINIMUM_3_9,
+     check_apt_requirements,
+     check_imgsz,
+     check_requirements,
+     check_version,
+     is_intel,
+     is_sudo_available,
+ )
+ from ultralytics.utils.export import (
+     keras2pb,
+     onnx2engine,
+     onnx2saved_model,
+     pb2tfjs,
+     tflite2edgetpu,
+     torch2imx,
+     torch2onnx,
+ )
+ from ultralytics.utils.files import file_size
+ from ultralytics.utils.metrics import batch_probiou
+ from ultralytics.utils.nms import TorchNMS
+ from ultralytics.utils.ops import Profile
+ from ultralytics.utils.patches import arange_patch
+ from ultralytics.utils.torch_utils import (
+     TORCH_1_10,
+     TORCH_1_11,
+     TORCH_1_13,
+     TORCH_2_1,
+     TORCH_2_4,
+     TORCH_2_9,
+     select_device,
+ )
+
+
+ def export_formats():
+     """Return a dictionary of Ultralytics YOLO export formats."""
+     x = [
+         ["PyTorch", "-", ".pt", True, True, []],
+         ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "half", "nms", "dynamic"]],
+         ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]],
+         [
+             "OpenVINO",
+             "openvino",
+             "_openvino_model",
+             True,
+             False,
+             ["batch", "dynamic", "half", "int8", "nms", "fraction"],
+         ],
+         [
+             "TensorRT",
+             "engine",
+             ".engine",
+             False,
+             True,
+             ["batch", "dynamic", "half", "int8", "simplify", "nms", "fraction"],
+         ],
+         ["CoreML", "coreml", ".mlpackage", True, False, ["batch", "dynamic", "half", "int8", "nms"]],
+         ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]],
+         ["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
+         ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms", "fraction"]],
+         ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
+         ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]],
+         ["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
+         ["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
+         ["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
+         ["IMX", "imx", "_imx_model", True, True, ["int8", "fraction", "nms"]],
+         ["RKNN", "rknn", "_rknn_model", False, False, ["batch", "name"]],
+         ["ExecuTorch", "executorch", "_executorch_model", True, False, ["batch"]],
+         ["Axelera", "axelera", "_axelera_model", False, False, ["batch", "int8", "fraction"]],
+     ]
+     return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x)))
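`export_formats()` returns column-major data: each key maps to a tuple spanning all formats, so row-wise inspection means zipping the values back together. A small sketch:

```python
formats = export_formats()

# Re-zip the six column tuples into per-format rows
for name, arg, suffix, cpu, gpu, supported in zip(*formats.values()):
    if "int8" in supported:  # e.g. list only formats that accept INT8 quantization
        print(f"{name:<24} format={arg:<12} suffix={suffix}")
```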
+
+
+ def best_onnx_opset(onnx, cuda=False) -> int:
+     """Return max ONNX opset for this torch version with ONNX fallback."""
+     if TORCH_2_4:  # _constants.ONNX_MAX_OPSET first defined in torch 1.13
+         opset = torch.onnx.utils._constants.ONNX_MAX_OPSET - 1  # use second-latest version for safety
+         if cuda:
+             opset -= 2  # fix CUDA ONNXRuntime NMS squeeze op errors
+     else:
+         version = ".".join(TORCH_VERSION.split(".")[:2])
+         opset = {
+             "1.8": 12,
+             "1.9": 12,
+             "1.10": 13,
+             "1.11": 14,
+             "1.12": 15,
+             "1.13": 17,
+             "2.0": 17,  # reduced from 18 to fix ONNX errors
+             "2.1": 17,  # reduced from 19
+             "2.2": 17,  # reduced from 19
+             "2.3": 17,  # reduced from 19
+             "2.4": 20,
+             "2.5": 20,
+             "2.6": 20,
+             "2.7": 20,
+             "2.8": 23,
+         }.get(version, 12)
+     return min(opset, onnx.defs.onnx_opset_version())
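Calling the helper directly shows which opset the exporter would select; the result is capped by both the installed torch and the installed `onnx` package. The printed values below are illustrative, not guaranteed:

```python
import onnx

# cuda=True applies the extra safety margin used for ONNXRuntime NMS ops on GPU
print(best_onnx_opset(onnx))             # e.g. 20 on torch 2.4-2.7
print(best_onnx_opset(onnx, cuda=True))  # lower when exporting on CUDA (torch>=2.4 branch)
```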
+
+
+ def validate_args(format, passed_args, valid_args):
+     """Validate arguments based on the export format.
+
+     Args:
+         format (str): The export format.
+         passed_args (Namespace): The arguments used during export.
+         valid_args (list): List of valid arguments for the format.
+
+     Raises:
+         AssertionError: If an unsupported argument is used, or if the format lacks supported argument listings.
+     """
+     export_args = ["half", "int8", "dynamic", "keras", "nms", "batch", "fraction"]
+
+     assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed."
+     custom = {"batch": 1, "data": None, "device": None}  # exporter defaults
+     default_args = get_cfg(DEFAULT_CFG, custom)
+     for arg in export_args:
+         not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None)
+         if not_default:
+             assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'"
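The check fails fast instead of silently ignoring unsupported flags. A small demonstration, using the valid-argument list that `export_formats()` declares for TensorFlow GraphDef:

```python
from ultralytics.cfg import get_cfg
from ultralytics.utils import DEFAULT_CFG

args = get_cfg(DEFAULT_CFG, {"batch": 1, "dynamic": True})  # 'dynamic' is not valid for 'pb'
try:
    validate_args("pb", args, ["batch"])
except AssertionError as e:
    print(e)  # ERROR ❌️ argument 'dynamic' is not supported for format='pb'
```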
+
+
+ def try_export(inner_func):
+     """YOLO export decorator, i.e. @try_export."""
+     inner_args = get_default_args(inner_func)
+
+     def outer_func(*args, **kwargs):
+         """Export a model."""
+         prefix = inner_args["prefix"]
+         dt = 0.0
+         try:
+             with Profile() as dt:
+                 f = inner_func(*args, **kwargs)  # exported file/dir or tuple of (file/dir, *)
+             path = f if isinstance(f, (str, Path)) else f[0]
+             mb = file_size(path)
+             assert mb > 0.0, "0.0 MB output model size"
+             LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{path}' ({mb:.1f} MB)")
+             return f
+         except Exception as e:
+             LOGGER.error(f"{prefix} export failure {dt.t:.1f}s: {e}")
+             raise e
+
+     return outer_func
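Any exporter decorated with `@try_export` gets timing, file-size reporting, and error logging for free; the only contract is a `prefix` keyword default (read via `get_default_args`) and returning the produced path. A hypothetical example, assuming this module's own imports are in scope:

```python
@try_export
def export_dummy(prefix=colorstr("Dummy:")):
    """Hypothetical exporter: write an artifact and return its path."""
    f = Path("model.dummy")
    f.write_bytes(b"\x00" * 1024)  # stand-in for a real model file
    return f

export_dummy()  # logs success with elapsed time and size, or logs and re-raises on failure
```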
+
+
+ class Exporter:
+     """A class for exporting YOLO models to various formats.
+
+     This class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,
+     TensorFlow, and others. It handles format validation, device selection, model preparation, and the actual export
+     process for each supported format.
+
+     Attributes:
+         args (SimpleNamespace): Configuration arguments for the exporter.
+         callbacks (dict): Dictionary of callback functions for different export events.
+         im (torch.Tensor): Input tensor for model inference during export.
+         model (torch.nn.Module): The YOLO model to be exported.
+         file (Path): Path to the model file being exported.
+         output_shape (tuple): Shape of the model output tensor(s).
+         pretty_name (str): Formatted model name for display purposes.
+         metadata (dict): Model metadata including description, author, version, etc.
+         device (torch.device): Device on which the model is loaded.
+         imgsz (tuple): Input image size for the model.
+
+     Methods:
+         __call__: Main export method that handles the export process.
+         get_int8_calibration_dataloader: Build dataloader for INT8 calibration.
+         export_torchscript: Export model to TorchScript format.
+         export_onnx: Export model to ONNX format.
+         export_openvino: Export model to OpenVINO format.
+         export_paddle: Export model to PaddlePaddle format.
+         export_mnn: Export model to MNN format.
+         export_ncnn: Export model to NCNN format.
+         export_coreml: Export model to CoreML format.
+         export_engine: Export model to TensorRT format.
+         export_saved_model: Export model to TensorFlow SavedModel format.
+         export_pb: Export model to TensorFlow GraphDef format.
+         export_tflite: Export model to TensorFlow Lite format.
+         export_edgetpu: Export model to Edge TPU format.
+         export_tfjs: Export model to TensorFlow.js format.
+         export_rknn: Export model to RKNN format.
+         export_imx: Export model to IMX format.
+
+     Examples:
+         Export a YOLOv8 model to ONNX format
+         >>> from ultralytics.engine.exporter import Exporter
+         >>> exporter = Exporter()
+         >>> exporter(model="yolov8n.pt")  # exports to yolov8n.onnx
+
+         Export with specific arguments
+         >>> args = {"format": "onnx", "dynamic": True, "half": True}
+         >>> exporter = Exporter(overrides=args)
+         >>> exporter(model="yolov8n.pt")
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """Initialize the Exporter class.
+
+         Args:
+             cfg (str, optional): Path to a configuration file.
+             overrides (dict, optional): Configuration overrides.
+             _callbacks (dict, optional): Dictionary of callback functions.
+         """
+         self.args = get_cfg(cfg, overrides)
+         self.callbacks = _callbacks or callbacks.get_default_callbacks()
+         callbacks.add_integration_callbacks(self)
+
+     def __call__(self, model=None) -> str:
+         """Export a model and return the final exported path as a string.
+
+         Returns:
+             (str): Path to the exported file or directory (the last export artifact).
+         """
+         t = time.time()
+         fmt = self.args.format.lower()  # to lowercase
+         if fmt in {"tensorrt", "trt"}:  # 'engine' aliases
+             fmt = "engine"
+         if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}:  # 'coreml' aliases
+             fmt = "coreml"
+         fmts_dict = export_formats()
+         fmts = tuple(fmts_dict["Argument"][1:])  # available export formats
+         if fmt not in fmts:
+             import difflib
+
+             # Get the closest match if format is invalid
+             matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6)  # 60% similarity required to match
+             if not matches:
+                 msg = "Model is already in PyTorch format." if fmt == "pt" else f"Invalid export format='{fmt}'."
+                 raise ValueError(f"{msg} Valid formats are {fmts}")
+             LOGGER.warning(f"Invalid export format='{fmt}', updating to format='{matches[0]}'")
+             fmt = matches[0]
+         flags = [x == fmt for x in fmts]
+         if sum(flags) != 1:
+             raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
+         (
+             jit,
+             onnx,
+             xml,
+             engine,
+             coreml,
+             saved_model,
+             pb,
+             tflite,
+             edgetpu,
+             tfjs,
+             paddle,
+             mnn,
+             ncnn,
+             imx,
+             rknn,
+             executorch,
+             axelera,
+         ) = flags  # export booleans
+
+         is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
+
+         # Device
+         dla = None
+         if engine and self.args.device is None:
+             LOGGER.warning("TensorRT requires GPU export, automatically assigning device=0")
+             self.args.device = "0"
+         if engine and "dla" in str(self.args.device):  # convert int/list to str first
+             device_str = str(self.args.device)
+             dla = device_str.rsplit(":", 1)[-1]
+             self.args.device = "0"  # update device to "0"
+             assert dla in {"0", "1"}, f"Expected device 'dla:0' or 'dla:1', but got {device_str}."
+         if imx and self.args.device is None and torch.cuda.is_available():
+             LOGGER.warning("Exporting on CPU while CUDA is available, setting device=0 for faster export on GPU.")
+             self.args.device = "0"  # update device to "0"
+         self.device = select_device("cpu" if self.args.device is None else self.args.device)
+
+         # Argument compatibility checks
+         fmt_keys = fmts_dict["Arguments"][flags.index(True) + 1]
+         validate_args(fmt, self.args, fmt_keys)
+         if axelera:
+             if not IS_PYTHON_3_10:
+                 raise SystemError("Axelera export only supported on Python 3.10.")
+             if not self.args.int8:
+                 LOGGER.warning("Setting int8=True for Axelera mixed-precision export.")
+                 self.args.int8 = True
+             if model.task not in {"detect"}:
+                 raise ValueError("Axelera export only supported for detection models.")
+             if not self.args.data:
+                 self.args.data = "coco128.yaml"  # Axelera default to coco128.yaml
+         if imx:
+             if not self.args.int8:
+                 LOGGER.warning("IMX export requires int8=True, setting int8=True.")
+                 self.args.int8 = True
+             if not self.args.nms and model.task in {"detect", "pose", "segment"}:
+                 LOGGER.warning("IMX export requires nms=True, setting nms=True.")
+                 self.args.nms = True
+             if model.task not in {"detect", "pose", "classify", "segment"}:
+                 raise ValueError(
+                     "IMX export only supported for detection, pose estimation, classification, and segmentation models."
+                 )
+         if not hasattr(model, "names"):
+             model.names = default_class_names()
+         model.names = check_class_names(model.names)
+         if self.args.half and self.args.int8:
+             LOGGER.warning("half=True and int8=True are mutually exclusive, setting half=False.")
+             self.args.half = False
+         if self.args.half and jit and self.device.type == "cpu":
+             LOGGER.warning(
+                 "half=True only compatible with GPU export for TorchScript, i.e. use device=0, setting half=False."
+             )
+             self.args.half = False
+         self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
+         if self.args.optimize:
+             assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
+             assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
+         if rknn:
+             if not self.args.name:
+                 LOGGER.warning(
+                     "Rockchip RKNN export requires the 'name' arg to specify the processor type. "
+                     "Using default name='rk3588'."
+                 )
+                 self.args.name = "rk3588"
+             self.args.name = self.args.name.lower()
+             assert self.args.name in RKNN_CHIPS, (
+                 f"Invalid processor name '{self.args.name}' for Rockchip RKNN export. Valid names are {RKNN_CHIPS}."
+             )
+         if self.args.nms:
+             assert not isinstance(model, ClassificationModel), "'nms=True' is not valid for classification models."
+             assert not tflite or not ARM64 or not LINUX, "TFLite export with NMS unsupported on ARM64 Linux"
+             assert not is_tf_format or TORCH_1_13, "TensorFlow exports with NMS require torch>=1.13"
+             assert not onnx or TORCH_1_13, "ONNX export with NMS requires torch>=1.13"
+             if getattr(model, "end2end", False) or isinstance(model.model[-1], RTDETRDecoder):
+                 LOGGER.warning("'nms=True' is not available for end2end models. Forcing 'nms=False'.")
+                 self.args.nms = False
+             self.args.conf = self.args.conf or 0.25  # set conf default value for nms export
+         if (engine or coreml or self.args.nms) and self.args.dynamic and self.args.batch == 1:
+             LOGGER.warning(
+                 f"'dynamic=True' model with '{'nms=True' if self.args.nms else f'format={self.args.format}'}' requires max batch size, i.e. 'batch=16'"
+             )
+         if edgetpu:
+             if not LINUX or ARM64:
+                 raise SystemError(
+                     "Edge TPU export only supported on non-aarch64 Linux. See https://coral.ai/docs/edgetpu/compiler"
+                 )
+             elif self.args.batch != 1:  # see github.com/ultralytics/ultralytics/pull/13420
+                 LOGGER.warning("Edge TPU export requires batch size 1, setting batch=1.")
+                 self.args.batch = 1
+         if isinstance(model, WorldModel):
+             LOGGER.warning(
+                 "YOLOWorld (original version) export is not supported to any format. "
+                 "YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to "
+                 "(torchscript, onnx, openvino, engine, coreml) formats. "
+                 "See https://docs.ultralytics.com/models/yolo-world for details."
+             )
+             model.clip_model = None  # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
+         if self.args.int8 and not self.args.data:
+             self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")]  # assign default data
+             LOGGER.warning(
+                 f"INT8 export requires the 'data' arg for calibration. Using default 'data={self.args.data}'."
+             )
+         if tfjs and (ARM64 and LINUX):
+             raise SystemError("TF.js exports are not currently supported on ARM64 Linux")
+         # Recommend OpenVINO when exporting on Intel hardware
+         if SETTINGS.get("openvino_msg"):
+             if is_intel():
+                 LOGGER.info(
+                     "💡 ProTip: Export to OpenVINO format for best performance on Intel hardware."
+                     " Learn more at https://docs.ultralytics.com/integrations/openvino/"
+                 )
+             SETTINGS["openvino_msg"] = False
+
+         # Input
+         im = torch.zeros(self.args.batch, model.yaml.get("channels", 3), *self.imgsz).to(self.device)
+         file = Path(
+             getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
+         )
+         if file.suffix in {".yaml", ".yml"}:
+             file = Path(file.name)
+
+         # Update model
+         model = deepcopy(model).to(self.device)
+         for p in model.parameters():
+             p.requires_grad = False
+         model.eval()
+         model.float()
+         model = model.fuse()
+
+         if imx:
+             from ultralytics.utils.export.imx import FXModel
+
+             model = FXModel(model, self.imgsz)
+         if tflite or edgetpu:
+             from ultralytics.utils.export.tensorflow import tf_wrapper
+
+             model = tf_wrapper(model)
+         for m in model.modules():
+             if isinstance(m, Classify):
+                 m.export = True
+             if isinstance(m, (Detect, RTDETRDecoder)):  # includes all Detect subclasses like Segment, Pose, OBB
+                 m.dynamic = self.args.dynamic
+                 m.export = True
+                 m.format = self.args.format
+                 m.max_det = self.args.max_det
+                 m.xyxy = self.args.nms and not coreml
+                 m.shape = None  # reset cached shape for new export input size
+                 if hasattr(model, "pe") and hasattr(m, "fuse"):  # for YOLOE models
+                     m.fuse(model.pe.to(self.device))
+             elif isinstance(m, C2f) and not is_tf_format:
+                 # EdgeTPU does not support FlexSplitV, while split provides a cleaner ONNX graph
+                 m.forward = m.forward_split
+
+         y = None
+         for _ in range(2):  # dry runs
+             y = NMSModel(model, self.args)(im) if self.args.nms and not coreml and not imx else model(im)
+         if self.args.half and (onnx or jit) and self.device.type != "cpu":
+             im, model = im.half(), model.half()  # to FP16
+
+         # Assign
+         self.im = im
+         self.model = model
+         self.file = file
+         self.output_shape = (
+             tuple(y.shape)
+             if isinstance(y, torch.Tensor)
+             else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
+         )
+         self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
+         data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
+         description = f"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}"
+         self.metadata = {
+             "description": description,
+             "author": "Ultralytics",
+             "date": datetime.now().isoformat(),
+             "version": __version__,
+             "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+             "docs": "https://docs.ultralytics.com",
+             "stride": int(max(model.stride)),
+             "task": model.task,
+             "batch": self.args.batch,
+             "imgsz": self.imgsz,
+             "names": model.names,
+             "args": {k: v for k, v in self.args if k in fmt_keys},
+             "channels": model.yaml.get("channels", 3),
+         }  # model metadata
+         if dla is not None:
+             self.metadata["dla"] = dla  # make sure `AutoBackend` uses correct dla device if it has one
+         if model.task == "pose":
+             self.metadata["kpt_shape"] = model.model[-1].kpt_shape
+             if hasattr(model, "kpt_names"):
+                 self.metadata["kpt_names"] = model.kpt_names
+
+         LOGGER.info(
+             f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
+             f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)"
+         )
+         self.run_callbacks("on_export_start")
+         # Exports
+         f = [""] * len(fmts)  # exported filenames
+         if jit:  # TorchScript
+             f[0] = self.export_torchscript()
+         if engine:  # TensorRT required before ONNX
+             f[1] = self.export_engine(dla=dla)
+         if onnx:  # ONNX
+             f[2] = self.export_onnx()
+         if xml:  # OpenVINO
+             f[3] = self.export_openvino()
+         if coreml:  # CoreML
+             f[4] = self.export_coreml()
+         if is_tf_format:  # TensorFlow formats
+             self.args.int8 |= edgetpu
+             f[5], keras_model = self.export_saved_model()
+             if pb or tfjs:  # pb prerequisite to tfjs
+                 f[6] = self.export_pb(keras_model=keras_model)
+             if tflite:
+                 f[7] = self.export_tflite()
+             if edgetpu:
+                 f[8] = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
+             if tfjs:
+                 f[9] = self.export_tfjs()
+         if paddle:  # PaddlePaddle
+             f[10] = self.export_paddle()
+         if mnn:  # MNN
+             f[11] = self.export_mnn()
+         if ncnn:  # NCNN
+             f[12] = self.export_ncnn()
+         if imx:
+             f[13] = self.export_imx()
+         if rknn:
+             f[14] = self.export_rknn()
+         if executorch:
+             f[15] = self.export_executorch()
+         if axelera:
+             f[16] = self.export_axelera()
+
+         # Finish
+         f = [str(x) for x in f if x]  # filter out '' and None
+         if any(f):
+             f = str(Path(f[-1]))
+             square = self.imgsz[0] == self.imgsz[1]
+             s = (
+                 ""
+                 if square
+                 else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
+                 f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
+             )
+             imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
+             predict_data = f"data={data}" if model.task == "segment" and pb else ""
+             q = "int8" if self.args.int8 else "half" if self.args.half else ""  # quantization
+             LOGGER.info(
+                 f"\nExport complete ({time.time() - t:.1f}s)"
+                 f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+                 f"\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}"
+                 f"\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}"
+                 f"\nVisualize: https://netron.app"
+             )
+
+         self.run_callbacks("on_export_end")
+         return f  # path to final export artifact
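In normal use this `__call__` is reached through `YOLO.export()`, which forwards keyword arguments as overrides; driving the class directly looks like the sketch below (mirroring the class docstring, with the underlying `nn.Module` passed in):

```python
from ultralytics import YOLO
from ultralytics.engine.exporter import Exporter

model = YOLO("yolo11n.pt")
exporter = Exporter(overrides={"format": "torchscript", "imgsz": 640})
path = exporter(model=model.model)  # returns the exported artifact path as a string
print(path)  # e.g. 'yolo11n.torchscript'
```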
+
+     def get_int8_calibration_dataloader(self, prefix=""):
+         """Build and return a dataloader for calibration of INT8 models."""
+         LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
+         data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)
+         dataset = YOLODataset(
+             data[self.args.split or "val"],
+             data=data,
+             fraction=self.args.fraction,
+             task=self.model.task,
+             imgsz=self.imgsz[0],
+             augment=False,
+             batch_size=self.args.batch,
+         )
+         n = len(dataset)
+         if n < self.args.batch:
+             raise ValueError(
+                 f"The calibration dataset ({n} images) must have at least as many images as the batch size "
+                 f"('batch={self.args.batch}')."
+             )
+         elif self.args.format == "axelera" and n < 100:
+             LOGGER.warning(f"{prefix} >100 images required for Axelera calibration, found {n} images.")
+         elif self.args.format != "axelera" and n < 300:
+             LOGGER.warning(f"{prefix} >300 images recommended for INT8 calibration, found {n} images.")
+         return build_dataloader(dataset, batch=self.args.batch, workers=0, drop_last=True)  # required for batch loading
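The returned loader yields standard Ultralytics batch dicts whose `"img"` entry holds uint8 image tensors; the INT8 backends (NNCF, TensorRT) normalize these themselves via their transform hooks. A rough sketch of wiring it up by hand, with hypothetical override values:

```python
from ultralytics import YOLO
from ultralytics.engine.exporter import Exporter

exporter = Exporter(overrides={"format": "openvino", "int8": True, "data": "coco8.yaml", "batch": 1})
exporter.model = YOLO("yolo11n.pt").model  # the method reads task/imgsz/args off the instance
exporter.imgsz = [640, 640]

batch = next(iter(exporter.get_int8_calibration_dataloader(prefix="INT8:")))
print(batch["img"].shape, batch["img"].dtype)  # e.g. torch.Size([1, 3, 640, 640]) torch.uint8
```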
+
+     @try_export
+     def export_torchscript(self, prefix=colorstr("TorchScript:")):
+         """Export YOLO model to TorchScript format."""
+         LOGGER.info(f"\n{prefix} starting export with torch {TORCH_VERSION}...")
+         f = self.file.with_suffix(".torchscript")
+
+         ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
+         extra_files = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
+         if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
+             LOGGER.info(f"{prefix} optimizing for mobile...")
+             from torch.utils.mobile_optimizer import optimize_for_mobile
+
+             optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
+         else:
+             ts.save(str(f), _extra_files=extra_files)
+         return f
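Because the metadata rides inside the TorchScript archive as `config.txt`, it can be read back at load time through `torch.jit.load`'s `_extra_files` hook:

```python
import json
import torch

extra_files = {"config.txt": ""}  # populated in place by torch.jit.load
ts = torch.jit.load("yolo11n.torchscript", _extra_files=extra_files)
meta = json.loads(extra_files["config.txt"])
print(meta["stride"], meta["imgsz"])  # e.g. 32 [640, 640]
```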
+
+     @try_export
+     def export_onnx(self, prefix=colorstr("ONNX:")):
+         """Export YOLO model to ONNX format."""
+         requirements = ["onnx>=1.12.0,<2.0.0"]
+         if self.args.simplify:
+             requirements += ["onnxslim>=0.1.71", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
+         check_requirements(requirements)
+         import onnx
+
+         opset = self.args.opset or best_onnx_opset(onnx, cuda="cuda" in self.device.type)
+         LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset}...")
+         if self.args.nms:
+             assert TORCH_1_13, f"'nms=True' ONNX export requires torch>=1.13 (found torch=={TORCH_VERSION})"
+
+         f = str(self.file.with_suffix(".onnx"))
+         output_names = ["output0", "output1"] if self.model.task == "segment" else ["output0"]
+         dynamic = self.args.dynamic
+         if dynamic:
+             dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
+             if isinstance(self.model, SegmentationModel):
+                 dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
+                 dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
+             elif isinstance(self.model, DetectionModel):
+                 dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
+             if self.args.nms:  # only batch size is dynamic with NMS
+                 dynamic["output0"].pop(2)
+         if self.args.nms and self.model.task == "obb":
+             self.args.opset = opset  # for NMSModel
+
+         with arange_patch(self.args):
+             torch2onnx(
+                 NMSModel(self.model, self.args) if self.args.nms else self.model,
+                 self.im,
+                 f,
+                 opset=opset,
+                 input_names=["images"],
+                 output_names=output_names,
+                 dynamic=dynamic or None,
+             )
+
+         # Checks
+         model_onnx = onnx.load(f)  # load onnx model
+
+         # Simplify
+         if self.args.simplify:
+             try:
+                 import onnxslim
+
+                 LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...")
+                 model_onnx = onnxslim.slim(model_onnx)
+
+             except Exception as e:
+                 LOGGER.warning(f"{prefix} simplifier failure: {e}")
+
+         # Metadata
+         for k, v in self.metadata.items():
+             meta = model_onnx.metadata_props.add()
+             meta.key, meta.value = k, str(v)
+
+         # IR version
+         if getattr(model_onnx, "ir_version", 0) > 10:
+             LOGGER.info(f"{prefix} limiting IR version {model_onnx.ir_version} to 10 for ONNXRuntime compatibility...")
+             model_onnx.ir_version = 10
+
+         # FP16 conversion for CPU export (GPU exports are already FP16 from model.half() during tracing)
+         if self.args.half and self.args.format == "onnx" and self.device.type == "cpu":
+             try:
+                 from onnxruntime.transformers import float16
+
+                 LOGGER.info(f"{prefix} converting to FP16...")
+                 model_onnx = float16.convert_float_to_float16(model_onnx, keep_io_types=True)
+             except Exception as e:
+                 LOGGER.warning(f"{prefix} FP16 conversion failure: {e}")
+
+         onnx.save(model_onnx, f)
+         return f
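A quick way to validate the file is to run it straight through ONNX Runtime; this sketch assumes a default 640x640 single-output detection export:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("yolo11n.onnx", providers=["CPUExecutionProvider"])
im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # BCHW input scaled to [0, 1]
(output,) = session.run(None, {"images": im})
print(output.shape)  # e.g. (1, 84, 8400) for an 80-class detection model
```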
+
+     @try_export
+     def export_openvino(self, prefix=colorstr("OpenVINO:")):
+         """Export YOLO model to OpenVINO format."""
+         # OpenVINO <= 2025.1.0 errors on macOS 15.4+: https://github.com/openvinotoolkit/openvino/issues/30023
+         check_requirements("openvino>=2025.2.0" if MACOS and MACOS_VERSION >= "15.4" else "openvino>=2024.0.0")
+         import openvino as ov
+
+         LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
+         assert TORCH_2_1, f"OpenVINO export requires torch>=2.1 but torch=={TORCH_VERSION} is installed"
+         ov_model = ov.convert_model(
+             NMSModel(self.model, self.args) if self.args.nms else self.model,
+             input=None if self.args.dynamic else [self.im.shape],
+             example_input=self.im,
+         )
+
+         def serialize(ov_model, file):
+             """Set RT info, serialize, and save metadata YAML."""
+             ov_model.set_rt_info("YOLO", ["model_info", "model_type"])
+             ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
+             ov_model.set_rt_info(114, ["model_info", "pad_value"])
+             ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
+             ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
+             ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
+             if self.model.task != "classify":
+                 ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
+
+             ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
+             YAML.save(Path(file).parent / "metadata.yaml", self.metadata)  # add metadata.yaml
+
+         if self.args.int8:
+             fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
+             fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
+             # INT8 requires nncf, nncf requires packaging>=23.2 https://github.com/openvinotoolkit/nncf/issues/3463
+             check_requirements("packaging>=23.2")  # must be installed first to build nncf wheel
+             check_requirements("nncf>=2.14.0")
+             import nncf
+
+             # Generate calibration data for integer quantization
+             ignored_scope = None
+             if isinstance(self.model.model[-1], Detect):
+                 # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect
+                 head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
+                 ignored_scope = nncf.IgnoredScope(  # ignore operations
+                     patterns=[
+                         f".*{head_module_name}/.*/Add",
+                         f".*{head_module_name}/.*/Sub*",
+                         f".*{head_module_name}/.*/Mul*",
+                         f".*{head_module_name}/.*/Div*",
+                         f".*{head_module_name}\\.dfl.*",
+                     ],
+                     types=["Sigmoid"],
+                 )
+
+             quantized_ov_model = nncf.quantize(
+                 model=ov_model,
+                 calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), self._transform_fn),
+                 preset=nncf.QuantizationPreset.MIXED,
+                 ignored_scope=ignored_scope,
+             )
+             serialize(quantized_ov_model, fq_ov)
+             return fq
+
+         f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
+         f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
+
+         serialize(ov_model, f_ov)
+         return f
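The resulting IR directory compiles directly with the OpenVINO runtime; a minimal sketch assuming the default (non-INT8) output directory name:

```python
import numpy as np
import openvino as ov

compiled = ov.Core().compile_model("yolo11n_openvino_model/yolo11n.xml", device_name="CPU")
im = np.zeros((1, 3, 640, 640), dtype=np.float32)
output = compiled(im)[0]  # compiled models are callable; index selects the first output
print(output.shape)
```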
+
+     @try_export
+     def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
+         """Export YOLO model to PaddlePaddle format."""
+         assert not IS_JETSON, "Jetson Paddle exports not supported yet"
+         check_requirements(
+             (
+                 "paddlepaddle-gpu>=3.0.0,!=3.3.0"  # exclude 3.3.0 https://github.com/PaddlePaddle/Paddle/issues/77340
+                 if torch.cuda.is_available()
+                 else "paddlepaddle==3.0.0"  # pin 3.0.0 for ARM64
+                 if ARM64
+                 else "paddlepaddle>=3.0.0,!=3.3.0",  # exclude 3.3.0 https://github.com/PaddlePaddle/Paddle/issues/77340
+                 "x2paddle",
+             )
+         )
+         import x2paddle
+         from x2paddle.convert import pytorch2paddle
+
+         LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
+         f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")
+
+         pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im])  # export
+         YAML.save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
+         return f
+
+     @try_export
+     def export_mnn(self, prefix=colorstr("MNN:")):
+         """Export YOLO model to MNN format using MNN https://github.com/alibaba/MNN."""
+         assert TORCH_1_10, "MNN export requires torch>=1.10.0 to avoid segmentation faults"
+         f_onnx = self.export_onnx()  # get onnx model first
+
+         check_requirements("MNN>=2.9.6")
+         import MNN
+         from MNN.tools import mnnconvert
+
+         # Setup and checks
+         LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...")
+         assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
+         f = str(self.file.with_suffix(".mnn"))  # MNN model file
+         args = ["", "-f", "ONNX", "--modelFile", f_onnx, "--MNNModel", f, "--bizCode", json.dumps(self.metadata)]
+         if self.args.int8:
+             args.extend(("--weightQuantBits", "8"))
+         if self.args.half:
+             args.append("--fp16")
+         mnnconvert.convert(args)
+         # Remove the scratch file generated during conversion optimization
+         convert_scratch = Path(self.file.parent / ".__convert_external_data.bin")
+         if convert_scratch.exists():
+             convert_scratch.unlink()
+         return f
+
+     @try_export
+     def export_ncnn(self, prefix=colorstr("NCNN:")):
+         """Export YOLO model to NCNN format using PNNX https://github.com/pnnx/pnnx."""
+         # Use git source for ARM64 due to broken PyPI packages https://github.com/Tencent/ncnn/issues/6509
+         check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn", cmds="--no-deps")
+         check_requirements("pnnx")
+         import ncnn
+         import pnnx
+
+         LOGGER.info(f"\n{prefix} starting export with NCNN {ncnn.__version__} and PNNX {pnnx.__version__}...")
+         f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
+
+         ncnn_args = dict(
+             ncnnparam=(f / "model.ncnn.param").as_posix(),
+             ncnnbin=(f / "model.ncnn.bin").as_posix(),
+             ncnnpy=(f / "model_ncnn.py").as_posix(),
+         )
+
+         pnnx_args = dict(
+             ptpath=(f / "model.pt").as_posix(),
+             pnnxparam=(f / "model.pnnx.param").as_posix(),
+             pnnxbin=(f / "model.pnnx.bin").as_posix(),
+             pnnxpy=(f / "model_pnnx.py").as_posix(),
+             pnnxonnx=(f / "model.pnnx.onnx").as_posix(),
+         )
+
+         f.mkdir(exist_ok=True)  # make ncnn_model directory
+         pnnx.export(self.model, inputs=self.im, **ncnn_args, **pnnx_args, fp16=self.args.half, device=self.device.type)
+
+         for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_args.values()):
+             Path(f_debug).unlink(missing_ok=True)
+
+         YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
+         return str(f)
+
+     @try_export
+     def export_coreml(self, prefix=colorstr("CoreML:")):
+         """Export YOLO model to CoreML format."""
+         mlmodel = self.args.format.lower() == "mlmodel"  # legacy *.mlmodel export format requested
+         check_requirements(
+             ["coremltools>=9.0", "numpy>=1.14.5,<=2.3.5"]
+         )  # latest numpy 2.4.0rc1 breaks coremltools exports
+         import coremltools as ct
+
+         LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
+         assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux."
+         assert TORCH_1_11, "CoreML export requires torch>=1.11"
+         if self.args.batch > 1:
+             assert self.args.dynamic, (
+                 "batch sizes > 1 are not supported without 'dynamic=True' for CoreML export. Please retry with 'dynamic=True'."
+             )
+         if self.args.dynamic:
+             assert not self.args.nms, (
+                 "'nms=True' cannot be used together with 'dynamic=True' for CoreML export. Please disable one of them."
+             )
+             assert self.model.task != "classify", "'dynamic=True' is not supported for CoreML classification models."
+         f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
+         if f.is_dir():
+             shutil.rmtree(f)
+
+         classifier_config = None
+         if self.model.task == "classify":
+             classifier_config = ct.ClassifierConfig(list(self.model.names.values()))
+             model = self.model
+         elif self.model.task == "detect":
+             model = IOSDetectModel(self.model, self.im, mlprogram=not mlmodel) if self.args.nms else self.model
+         else:
+             if self.args.nms:
+                 LOGGER.warning(f"{prefix} 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
+             # TODO CoreML Segment and Pose model pipelining
+             model = self.model
+         ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
+
+         if self.args.dynamic:
+             input_shape = ct.Shape(
+                 shape=(
+                     ct.RangeDim(lower_bound=1, upper_bound=self.args.batch, default=1),
+                     self.im.shape[1],
+                     ct.RangeDim(lower_bound=32, upper_bound=self.imgsz[0] * 2, default=self.imgsz[0]),
+                     ct.RangeDim(lower_bound=32, upper_bound=self.imgsz[1] * 2, default=self.imgsz[1]),
+                 )
+             )
+             inputs = [ct.TensorType("image", shape=input_shape)]
+         else:
+             inputs = [ct.ImageType("image", shape=self.im.shape, scale=1 / 255, bias=[0.0, 0.0, 0.0])]
+
+         # Per Apple's documentation it is better to leave out minimum_deployment_target and let it be set
+         # internally based on the model conversion and output type.
+         # Setting minimum_deployment_target >= iOS16 will require setting compute_precision=ct.precision.FLOAT32.
+         # iOS16 adds better support for FP16, but none of the CoreML NMS specifications handle FP16 as input.
+         ct_model = ct.convert(
+             ts,
+             inputs=inputs,
+             classifier_config=classifier_config,
+             convert_to="neuralnetwork" if mlmodel else "mlprogram",
+         )
+         bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
+         if bits < 32:
+             if "kmeans" in mode:
+                 check_requirements("scikit-learn")  # scikit-learn package required for k-means quantization
+             if mlmodel:
+                 ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
+             elif bits == 8:  # mlprogram already quantized to FP16
+                 import coremltools.optimize.coreml as cto
+
+                 op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
+                 config = cto.OptimizationConfig(global_config=op_config)
+                 ct_model = cto.palettize_weights(ct_model, config=config)
+         if self.args.nms and self.model.task == "detect":
+             ct_model = self._pipeline_coreml(ct_model, weights_dir=None if mlmodel else ct_model.weights_dir)
+
+         m = self.metadata  # metadata dict
+         ct_model.short_description = m.pop("description")
+         ct_model.author = m.pop("author")
+         ct_model.license = m.pop("license")
+         ct_model.version = m.pop("version")
+         ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
+         if self.model.task == "classify":
+             ct_model.user_defined_metadata.update({"com.apple.coreml.model.preview.type": "imageClassifier"})
+
+         try:
+             ct_model.save(str(f))  # save *.mlpackage
+         except Exception as e:
+             LOGGER.warning(
+                 f"{prefix} CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
+                 f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
+             )
+             f = f.with_suffix(".mlmodel")
+             ct_model.save(str(f))
+         return f
990
+
991
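A minimal usage sketch for the method above (assumes the public Ultralytics YOLO API; file names are illustrative):

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    model.export(format="coreml", nms=True)  # writes yolo11n.mlpackage with the NMS pipeline
    model.export(format="mlmodel")  # legacy *.mlmodel container via convert_to="neuralnetwork"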
+     @try_export
+     def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
+         """Export YOLO model to TensorRT format https://developer.nvidia.com/tensorrt."""
+         assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
+         f_onnx = self.export_onnx()  # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
+
+         try:
+             import tensorrt as trt
+         except ImportError:
+             if LINUX:
+                 cuda_version = torch.version.cuda.split(".")[0]
+                 check_requirements(f"tensorrt-cu{cuda_version}>7.0.0,!=10.1.0")
+             import tensorrt as trt
+         check_version(trt.__version__, ">=7.0.0", hard=True)
+         check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
+
+         # Setup and checks
+         LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
+         assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
+         f = self.file.with_suffix(".engine")  # TensorRT engine file
+         onnx2engine(
+             f_onnx,
+             f,
+             self.args.workspace,
+             self.args.half,
+             self.args.int8,
+             self.args.dynamic,
+             self.im.shape,
+             dla=dla,
+             dataset=self.get_int8_calibration_dataloader(prefix) if self.args.int8 else None,
+             metadata=self.metadata,
+             verbose=self.args.verbose,
+             prefix=prefix,
+         )
+
+         return f
+
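A usage sketch for TensorRT engine export (assumes an NVIDIA GPU and the public YOLO API; INT8 additionally needs calibration data):

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    model.export(format="engine", device=0, half=True)  # FP16 yolo11n.engine
    model.export(format="engine", device=0, int8=True, data="coco8.yaml")  # INT8 with calibration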
+     @try_export
+     def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
+         """Export YOLO model to TensorFlow SavedModel format."""
+         cuda = torch.cuda.is_available()
+         try:
+             import tensorflow as tf
+         except ImportError:
+             check_requirements("tensorflow>=2.0.0,<=2.19.0")
+             import tensorflow as tf
+         check_requirements(
+             (
+                 "tf_keras<=2.19.0",  # required by 'onnx2tf' package
+                 "sng4onnx>=1.0.1",  # required by 'onnx2tf' package
+                 "onnx_graphsurgeon>=0.3.26",  # required by 'onnx2tf' package
+                 "ai-edge-litert>=1.2.0" + (",<1.4.0" if MACOS else ""),  # required by 'onnx2tf' package
+                 "onnx>=1.12.0,<2.0.0",
+                 "onnx2tf>=1.26.3",
+                 "onnxslim>=0.1.71",
+                 "onnxruntime-gpu" if cuda else "onnxruntime",
+                 "protobuf>=5",
+             ),
+             cmds="--extra-index-url https://pypi.ngc.nvidia.com",  # onnx_graphsurgeon only on NVIDIA
+         )
+
+         LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+         check_version(
+             tf.__version__,
+             ">=2.0.0",
+             name="tensorflow",
+             verbose=True,
+             msg="https://github.com/ultralytics/ultralytics/issues/5161",
+         )
+         f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
+         if f.is_dir():
+             shutil.rmtree(f)  # delete output folder
+
+         # Export to TF
+         images = None
+         if self.args.int8 and self.args.data:
+             images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
+             images = (
+                 torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz)
+                 .permute(0, 2, 3, 1)
+                 .numpy()
+                 .astype(np.float32)
+             )
+
+         # Export to ONNX
+         if isinstance(self.model.model[-1], RTDETRDecoder):
+             self.args.opset = self.args.opset or 19
+             assert 16 <= self.args.opset <= 19, "RTDETR export requires 16 <= opset <= 19"
+             self.args.simplify = True
+         f_onnx = self.export_onnx()  # ensure ONNX is available
+         keras_model = onnx2saved_model(
+             f_onnx,
+             f,
+             int8=self.args.int8,
+             images=images,
+             disable_group_convolution=self.args.format in {"tfjs", "edgetpu"},
+             prefix=prefix,
+         )
+         YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
+         # Add TFLite metadata
+         for file in f.rglob("*.tflite"):
+             if "quant_with_int16_act.tflite" in str(file):
+                 file.unlink()
+             else:
+                 self._add_tflite_metadata(file)
+
+         return str(f), keras_model  # or keras_model = tf.saved_model.load(f, tags=None, options=None)
+
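A usage sketch for SavedModel export (public YOLO API assumed; with int8=True the calibration images above are drawn from `data`):

    from ultralytics import YOLO

    YOLO("yolo11n.pt").export(format="saved_model", int8=True, data="coco8.yaml")
    # -> yolo11n_saved_model/ with metadata.yaml and quantized *.tflite variants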
+     @try_export
+     def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
+         """Export YOLO model to TensorFlow GraphDef *.pb format https://github.com/leimao/Frozen-Graph-TensorFlow."""
+         f = self.file.with_suffix(".pb")
+         keras2pb(keras_model, f, prefix)
+         return f
+
+     @try_export
+     def export_tflite(self, prefix=colorstr("TensorFlow Lite:")):
+         """Export YOLO model to TensorFlow Lite format."""
+         # BUG https://github.com/ultralytics/ultralytics/issues/13436
+         import tensorflow as tf
+
+         LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+         saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
+         if self.args.int8:
+             f = saved_model / f"{self.file.stem}_int8.tflite"  # fp32 in/out
+         elif self.args.half:
+             f = saved_model / f"{self.file.stem}_float16.tflite"  # fp32 in/out
+         else:
+             f = saved_model / f"{self.file.stem}_float32.tflite"
+         return str(f)
+
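The TFLite file name encodes the precision chosen above; a usage sketch (public YOLO API assumed):

    from ultralytics import YOLO

    YOLO("yolo11n.pt").export(format="tflite", half=True)
    # -> yolo11n_saved_model/yolo11n_float16.tflite (int8 -> *_int8.tflite, default -> *_float32.tflite)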
+     @try_export
+     def export_axelera(self, prefix=colorstr("Axelera:")):
+         """YOLO Axelera export."""
+         os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+         try:
+             from axelera import compiler
+         except ImportError:
+             check_apt_requirements(
+                 ["libllvm14", "libgirepository1.0-dev", "pkg-config", "libcairo2-dev", "build-essential", "cmake"]
+             )
+             check_requirements(
+                 "axelera-voyager-sdk==1.5.2",
+                 cmds="--extra-index-url https://software.axelera.ai/artifactory/axelera-runtime-pypi "
+                 "--extra-index-url https://software.axelera.ai/artifactory/axelera-dev-pypi",
+             )
+
+         from axelera import compiler
+         from axelera.compiler import CompilerConfig
+
+         self.args.opset = 17  # hardcode opset for Axelera
+         onnx_path = self.export_onnx()
+         model_name = Path(onnx_path).stem
+         export_path = Path(f"{model_name}_axelera_model")
+         export_path.mkdir(exist_ok=True)
+
+         if "C2PSA" in str(self.model):  # YOLO11
+             config = CompilerConfig(
+                 quantization_scheme="per_tensor_min_max",
+                 ignore_weight_buffers=False,
+                 resources_used=0.25,
+                 aipu_cores_used=1,
+                 multicore_mode="batch",
+                 output_axm_format=True,
+                 model_name=model_name,
+             )
+         else:  # YOLOv8
+             config = CompilerConfig(
+                 tiling_depth=6,
+                 split_buffer_promotion=True,
+                 resources_used=0.25,
+                 aipu_cores_used=1,
+                 multicore_mode="batch",
+                 output_axm_format=True,
+                 model_name=model_name,
+             )
+
+         qmodel = compiler.quantize(
+             model=onnx_path,
+             calibration_dataset=self.get_int8_calibration_dataloader(prefix),
+             config=config,
+             transform_fn=self._transform_fn,
+         )
+
+         compiler.compile(model=qmodel, config=config, output_dir=export_path)
+
+         axm_name = f"{model_name}.axm"
+         axm_src = Path(axm_name)
+         axm_dst = export_path / axm_name
+
+         if axm_src.exists():
+             axm_src.replace(axm_dst)
+
+         YAML.save(export_path / "metadata.yaml", self.metadata)
+
+         return export_path
+
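A usage sketch (the "axelera" format key is assumed from this method; the Axelera Voyager SDK is installed on first run, and calibration images come from `data`):

    from ultralytics import YOLO

    YOLO("yolo11n.pt").export(format="axelera", data="coco8.yaml")
    # -> yolo11n_axelera_model/ containing the compiled model, *.axm and metadata.yaml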
+     @try_export
+     def export_executorch(self, prefix=colorstr("ExecuTorch:")):
+         """Export a model to ExecuTorch (.pte) format into a dedicated directory and save the required metadata,
+         following Ultralytics conventions.
+         """
+         LOGGER.info(f"\n{prefix} starting export with ExecuTorch...")
+         assert TORCH_2_9, f"ExecuTorch export requires torch>=2.9.0 but torch=={TORCH_VERSION} is installed"
+
+         # BUG executorch build on arm64 Docker requires packaging>=22.0 https://github.com/pypa/setuptools/issues/4483
+         if LINUX and ARM64 and IS_DOCKER:
+             check_requirements("packaging>=22.0")
+
+         check_requirements("ruamel.yaml<0.19.0")
+         check_requirements("executorch==1.0.1", "flatbuffers")
+         # Pin numpy to avoid coremltools errors with numpy>=2.4.0, must be separate
+         check_requirements("numpy<=2.3.5")
+
+         from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+         from executorch.exir import to_edge_transform_and_lower
+
+         file_directory = Path(str(self.file).replace(self.file.suffix, "_executorch_model"))
+         file_directory.mkdir(parents=True, exist_ok=True)
+
+         file_pte = file_directory / self.file.with_suffix(".pte").name
+         sample_inputs = (self.im,)
+
+         et_program = to_edge_transform_and_lower(
+             torch.export.export(self.model, sample_inputs), partitioner=[XnnpackPartitioner()]
+         ).to_executorch()
+
+         with open(file_pte, "wb") as file:
+             file.write(et_program.buffer)
+
+         YAML.save(file_directory / "metadata.yaml", self.metadata)
+
+         return str(file_directory)
+
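A usage sketch (assumes torch>=2.9 and an "executorch" format key matching this method):

    from ultralytics import YOLO

    YOLO("yolo11n.pt").export(format="executorch")
    # -> yolo11n_executorch_model/ with yolo11n.pte and metadata.yaml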
+     @try_export
+     def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
+         """Export YOLO model to Edge TPU format https://coral.ai/docs/edgetpu/models-intro/."""
+         cmd = "edgetpu_compiler --version"
+         help_url = "https://coral.ai/docs/edgetpu/compiler/"
+         assert LINUX, f"export only supported on Linux. See {help_url}"
+         if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
+             LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
+             sudo = "sudo " if is_sudo_available() else ""
+             for c in (
+                 f"{sudo}mkdir -p /etc/apt/keyrings",
+                 f"curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | {sudo}gpg --dearmor -o /etc/apt/keyrings/google.gpg",
+                 f'echo "deb [signed-by=/etc/apt/keyrings/google.gpg] https://packages.cloud.google.com/apt coral-edgetpu-stable main" | {sudo}tee /etc/apt/sources.list.d/coral-edgetpu.list',
+             ):
+                 subprocess.run(c, shell=True, check=True)
+             check_apt_requirements(["edgetpu-compiler"])
+
+         ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]
+         LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
+         tflite2edgetpu(tflite_file=tflite_model, output_dir=tflite_model.parent, prefix=prefix)
+         f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model
+         self._add_tflite_metadata(f)
+         return f
+
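A usage sketch (Linux only; the Edge TPU compiler is installed automatically on first run, and it consumes the INT8 TFLite model produced earlier in the pipeline):

    from ultralytics import YOLO

    YOLO("yolo11n.pt").export(format="edgetpu")  # -> a *_edgetpu.tflite next to the source TFLite file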
+     @try_export
+     def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
+         """Export YOLO model to TensorFlow.js format."""
+         check_requirements("tensorflowjs")
+
+         f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
+         f_pb = str(self.file.with_suffix(".pb"))  # *.pb path
+         pb2tfjs(pb_file=f_pb, output_dir=f, half=self.args.half, int8=self.args.int8, prefix=prefix)
+         # Add metadata
+         YAML.save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
+         return f
+
+     @try_export
+     def export_rknn(self, prefix=colorstr("RKNN:")):
+         """Export YOLO model to RKNN format."""
+         LOGGER.info(f"\n{prefix} starting export with rknn-toolkit2...")
+
+         check_requirements("rknn-toolkit2")
+         if IS_COLAB:
+             # Prevent 'exit' from closing the notebook https://github.com/airockchip/rknn-toolkit2/issues/259
+             import builtins
+
+             builtins.exit = lambda: None
+
+         from rknn.api import RKNN
+
+         f = self.export_onnx()
+         export_path = Path(f"{Path(f).stem}_rknn_model")
+         export_path.mkdir(exist_ok=True)
+
+         rknn = RKNN(verbose=False)
+         rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], target_platform=self.args.name)
+         rknn.load_onnx(model=f)
+         rknn.build(do_quantization=False)  # TODO: Add quantization support
+         f = f.replace(".onnx", f"-{self.args.name}.rknn")
+         rknn.export_rknn(f"{export_path / f}")
+         YAML.save(export_path / "metadata.yaml", self.metadata)
+         return export_path
+
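A usage sketch; the `name` argument is repurposed here as the Rockchip target platform passed to rknn.config():

    from ultralytics import YOLO

    YOLO("yolo11n.pt").export(format="rknn", name="rk3588")
    # -> yolo11n_rknn_model/yolo11n-rk3588.rknn plus metadata.yaml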
+     @try_export
+     def export_imx(self, prefix=colorstr("IMX:")):
+         """Export YOLO model to IMX format."""
+         assert LINUX, (
+             "Export only supported on Linux. "
+             "See https://developer.aitrios.sony-semicon.com/en/docs/raspberry-pi-ai-camera/imx500-converter?version=3.17.3&progLang="
+         )
+         assert not ARM64, "IMX export is not supported on ARM64 architectures."
+         assert IS_PYTHON_MINIMUM_3_9, "IMX export is only supported on Python 3.9 or above."
+
+         if getattr(self.model, "end2end", False):
+             raise ValueError("IMX export is not supported for end2end models.")
+         check_requirements(
+             (
+                 "model-compression-toolkit>=2.4.1",
+                 "edge-mdt-cl<1.1.0",
+                 "edge-mdt-tpc>=1.2.0",
+                 "pydantic<=2.11.7",
+             )
+         )
+
+         check_requirements("imx500-converter[pt]>=3.17.3")
+
+         # Install Java>=17
+         try:
+             java_output = subprocess.run(["java", "--version"], check=True, capture_output=True).stdout.decode()
+             version_match = re.search(r"(?:openjdk|java) (\d+)", java_output)
+             java_version = int(version_match.group(1)) if version_match else 0
+             assert java_version >= 17, "Java version too old"
+         except (FileNotFoundError, subprocess.CalledProcessError, AssertionError):
+             if IS_UBUNTU or IS_DEBIAN_TRIXIE:
+                 LOGGER.info(f"\n{prefix} installing Java 21 for Ubuntu...")
+                 check_apt_requirements(["openjdk-21-jre"])
+             elif IS_RASPBERRYPI or IS_DEBIAN_BOOKWORM:
+                 LOGGER.info(f"\n{prefix} installing Java 17 for Raspberry Pi or Debian...")
+                 check_apt_requirements(["openjdk-17-jre"])
+
+         return torch2imx(
+             self.model,
+             self.file,
+             self.args.conf,
+             self.args.iou,
+             self.args.max_det,
+             metadata=self.metadata,
+             dataset=self.get_int8_calibration_dataloader(prefix),
+             prefix=prefix,
+         )
+
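A usage sketch (Linux x86_64 with Java>=17; the INT8 calibration dataloader above pulls images from `data`):

    from ultralytics import YOLO

    YOLO("yolo11n.pt").export(format="imx", data="coco8.yaml")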
+     def _add_tflite_metadata(self, file):
+         """Add metadata to *.tflite models per https://ai.google.dev/edge/litert/models/metadata."""
+         import zipfile
+
+         with zipfile.ZipFile(file, "a", zipfile.ZIP_DEFLATED) as zf:
+             zf.writestr("metadata.json", json.dumps(self.metadata, indent=2))
+
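Because the metadata is appended as a zip entry, it can be read back with the stdlib alone; a minimal sketch (file name illustrative):

    import json
    import zipfile

    with zipfile.ZipFile("yolo11n_saved_model/yolo11n_float32.tflite") as zf:
        metadata = json.loads(zf.read("metadata.json"))
    print(metadata["names"])  # class-name mapping saved by the exporter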
+     def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
+         """Create CoreML pipeline with NMS for YOLO detection models."""
+         import coremltools as ct
+
+         LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
+
+         # Output shapes
+         spec = model.get_spec()
+         outs = list(iter(spec.description.output))
+         if self.args.format == "mlmodel":  # mlmodel doesn't infer shapes automatically
+             outs[0].type.multiArrayType.shape[:] = self.output_shape[2], self.output_shape[1] - 4
+             outs[1].type.multiArrayType.shape[:] = self.output_shape[2], 4
+
+         # Checks
+         names = self.metadata["names"]
+         nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
+         nc = outs[0].type.multiArrayType.shape[-1]
+         if len(names) != nc:  # Hack fix for MLProgram NMS bug https://github.com/ultralytics/ultralytics/issues/22309
+             names = {**names, **{i: str(i) for i in range(len(names), nc)}}
+
+         # Model from spec
+         model = ct.models.MLModel(spec, weights_dir=weights_dir)
+
+         # Create NMS protobuf
+         nms_spec = ct.proto.Model_pb2.Model()
+         nms_spec.specificationVersion = spec.specificationVersion
+         for i in range(len(outs)):
+             decoder_output = model._spec.description.output[i].SerializeToString()
+             nms_spec.description.input.add()
+             nms_spec.description.input[i].ParseFromString(decoder_output)
+             nms_spec.description.output.add()
+             nms_spec.description.output[i].ParseFromString(decoder_output)
+
+         output_names = ["confidence", "coordinates"]
+         for i, name in enumerate(output_names):
+             nms_spec.description.output[i].name = name
+
+         for i, out in enumerate(outs):
+             ma_type = nms_spec.description.output[i].type.multiArrayType
+             ma_type.shapeRange.sizeRanges.add()
+             ma_type.shapeRange.sizeRanges[0].lowerBound = 0
+             ma_type.shapeRange.sizeRanges[0].upperBound = -1
+             ma_type.shapeRange.sizeRanges.add()
+             ma_type.shapeRange.sizeRanges[1].lowerBound = out.type.multiArrayType.shape[-1]
+             ma_type.shapeRange.sizeRanges[1].upperBound = out.type.multiArrayType.shape[-1]
+             del ma_type.shape[:]
+
+         nms = nms_spec.nonMaximumSuppression
+         nms.confidenceInputFeatureName = outs[0].name  # 1x507x80
+         nms.coordinatesInputFeatureName = outs[1].name  # 1x507x4
+         nms.confidenceOutputFeatureName = output_names[0]
+         nms.coordinatesOutputFeatureName = output_names[1]
+         nms.iouThresholdInputFeatureName = "iouThreshold"
+         nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
+         nms.iouThreshold = self.args.iou
+         nms.confidenceThreshold = self.args.conf
+         nms.pickTop.perClass = True
+         nms.stringClassLabels.vector.extend(names.values())
+         nms_model = ct.models.MLModel(nms_spec)
+
+         # Pipeline models together
+         pipeline = ct.models.pipeline.Pipeline(
+             input_features=[
+                 ("image", ct.models.datatypes.Array(3, ny, nx)),
+                 ("iouThreshold", ct.models.datatypes.Double()),
+                 ("confidenceThreshold", ct.models.datatypes.Double()),
+             ],
+             output_features=output_names,
+         )
+         pipeline.add_model(model)
+         pipeline.add_model(nms_model)
+
+         # Correct datatypes
+         pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
+         pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
+         pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
+
+         # Update metadata
+         pipeline.spec.specificationVersion = spec.specificationVersion
+         pipeline.spec.description.metadata.userDefined.update(
+             {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
+         )
+
+         # Save the model
+         model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
+         model.input_description["image"] = "Input image"
+         model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})"
+         model.input_description["confidenceThreshold"] = (
+             f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
+         )
+         model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
+         model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
+         LOGGER.info(f"{prefix} pipeline success")
+         return model
+
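Invoking the resulting pipeline with coremltools (prediction runs on macOS only; a sketch with illustrative paths, using the input/output names assembled above):

    import coremltools as ct
    from PIL import Image

    mlmodel = ct.models.MLModel("yolo11n.mlpackage")
    im = Image.open("bus.jpg").resize((640, 640))
    out = mlmodel.predict({"image": im, "iouThreshold": 0.45, "confidenceThreshold": 0.25})
    print(out["confidence"].shape, out["coordinates"].shape)  # (num_boxes, nc), (num_boxes, 4)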
+     @staticmethod
+     def _transform_fn(data_item) -> np.ndarray:
+         """Transformation function for Axelera/OpenVINO quantization preprocessing."""
+         data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
+         assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
+         im = data_item.numpy().astype(np.float32) / 255.0  # uint8 to fp32, 0-255 to 0.0-1.0
+         return im[None] if im.ndim == 3 else im
+
+     def add_callback(self, event: str, callback):
+         """Append the given callback to the specified event."""
+         self.callbacks[event].append(callback)
+
+     def run_callbacks(self, event: str):
+         """Execute all callbacks for a given event."""
+         for callback in self.callbacks.get(event, []):
+             callback(self)
+
+
+ class IOSDetectModel(torch.nn.Module):
+     """Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
+
+     def __init__(self, model, im, mlprogram=True):
+         """Initialize the IOSDetectModel class with a YOLO model and example image.
+
+         Args:
+             model (torch.nn.Module): The YOLO model to wrap.
+             im (torch.Tensor): Example input tensor with shape (B, C, H, W).
+             mlprogram (bool): Whether exporting to MLProgram format to fix NMS bug.
+         """
+         super().__init__()
+         _, _, h, w = im.shape  # batch, channel, height, width
+         self.model = model
+         self.nc = len(model.names)  # number of classes
+         self.mlprogram = mlprogram
+         if w == h:
+             self.normalize = 1.0 / w  # scalar
+         else:
+             self.normalize = torch.tensor(
+                 [1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h],  # broadcast (slower, smaller)
+                 device=next(model.parameters()).device,
+             )
+
+     def forward(self, x):
+         """Normalize predictions of object detection model with input size-dependent factors."""
+         xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
+         if self.mlprogram and self.nc % 80 != 0:  # NMS bug https://github.com/ultralytics/ultralytics/issues/22309
+             pad_length = int(((self.nc + 79) // 80) * 80) - self.nc  # pad class length to multiple of 80
+             cls = torch.nn.functional.pad(cls, (0, pad_length, 0, 0), "constant", 0)
+
+         return cls, xywh * self.normalize
+
+
+ class NMSModel(torch.nn.Module):
+     """Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
+
+     def __init__(self, model, args):
+         """Initialize the NMSModel.
+
+         Args:
+             model (torch.nn.Module): The model to wrap with NMS postprocessing.
+             args (Namespace): The export arguments.
+         """
+         super().__init__()
+         self.model = model
+         self.args = args
+         self.obb = model.task == "obb"
+         self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
+
+     def forward(self, x):
+         """Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
+
+         Args:
+             x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
+
+         Returns:
+             (torch.Tensor): Detections with shape (N, max_det, 4 + 2 + extra_shape), where N is the batch size.
+         """
+         from functools import partial
+
+         from torchvision.ops import nms
+
+         preds = self.model(x)
+         pred = preds[0] if isinstance(preds, tuple) else preds
+         kwargs = dict(device=pred.device, dtype=pred.dtype)
+         bs = pred.shape[0]
+         pred = pred.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
+         extra_shape = pred.shape[-1] - (4 + len(self.model.names))  # extras from Segment, OBB, Pose
+         if self.args.dynamic and self.args.batch > 1:  # batch size needs to always be same due to loop unroll
+             pad = torch.zeros(torch.max(torch.tensor(self.args.batch - bs), torch.tensor(0)), *pred.shape[1:], **kwargs)
+             pred = torch.cat((pred, pad))
+         boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)
+         scores, classes = scores.max(dim=-1)
+         self.args.max_det = min(pred.shape[1], self.args.max_det)  # in case num_anchors < max_det
+         # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape)
+         out = torch.zeros(pred.shape[0], self.args.max_det, boxes.shape[-1] + 2 + extra_shape, **kwargs)
+         for i in range(bs):
+             box, cls, score, extra = boxes[i], classes[i], scores[i], extras[i]
+             mask = score > self.args.conf
+             if self.is_tf or (self.args.format == "onnx" and self.obb):
+                 # TFLite GatherND error if mask is empty
+                 score *= mask
+                 # Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`
+                 mask = score.topk(min(self.args.max_det * 5, score.shape[0])).indices
+             box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
+             nmsbox = box.clone()
+             # `8` is the minimum value experimented to get correct NMS results for obb
+             multiplier = 8 if self.obb else 1 / max(len(self.model.names), 1)
+             # Normalize boxes for NMS since large values for class offset cause issues with int8 quantization
+             if self.args.format == "tflite":  # TFLite is already normalized
+                 nmsbox *= multiplier
+             else:
+                 nmsbox = multiplier * (nmsbox / torch.tensor(x.shape[2:], **kwargs).max())
+             if not self.args.agnostic_nms:  # class-wise NMS
+                 end = 2 if self.obb else 4
+                 # fully explicit expansion otherwise reshape error
+                 cls_offset = cls.view(cls.shape[0], 1).expand(cls.shape[0], end)
+                 offbox = nmsbox[:, :end] + cls_offset * multiplier
+                 nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
+             nms_fn = (
+                 partial(
+                     TorchNMS.fast_nms,
+                     use_triu=not (
+                         self.is_tf
+                         or (self.args.opset or 14) < 14
+                         or (self.args.format == "openvino" and self.args.int8)  # OpenVINO int8 error with triu
+                     ),
+                     iou_func=batch_probiou,
+                     exit_early=False,
+                 )
+                 if self.obb
+                 else nms
+             )
+             keep = nms_fn(
+                 torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,
+                 score,
+                 self.args.iou,
+             )[: self.args.max_det]
+             dets = torch.cat(
+                 [box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1).to(out.dtype), extra[keep]], dim=-1
+             )
+             # Zero-pad to max_det size to avoid reshape error
+             pad = (0, 0, 0, self.args.max_det - dets.shape[0])
+             out[i] = torch.nn.functional.pad(dets, pad)
+         return (out[:bs], preds[1]) if self.model.task == "segment" else out[:bs]
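A sketch of applying the wrapper directly (normally the Exporter does this internally when nms=True; the args namespace below mirrors the export arguments the forward pass reads, and the wrapped module is assumed to expose .names and .task as above):

    from types import SimpleNamespace

    import torch
    from ultralytics import YOLO

    yolo = YOLO("yolo11n.pt")
    args = SimpleNamespace(
        format="onnx", conf=0.25, iou=0.45, max_det=300, dynamic=False,
        batch=1, agnostic_nms=False, opset=14, int8=False,
    )
    wrapped = NMSModel(yolo.model.eval(), args)
    out = wrapped(torch.zeros(1, 3, 640, 640))  # (1, 300, 6): box, score, class per detection slot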