dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
ultralytics/engine/exporter.py
@@ -0,0 +1,1519 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+ """
+ Export a YOLO PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit.
+
+ Format | `format=argument` | Model
+ --- | --- | ---
+ PyTorch | - | yolo11n.pt
+ TorchScript | `torchscript` | yolo11n.torchscript
+ ONNX | `onnx` | yolo11n.onnx
+ OpenVINO | `openvino` | yolo11n_openvino_model/
+ TensorRT | `engine` | yolo11n.engine
+ CoreML | `coreml` | yolo11n.mlpackage
+ TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/
+ TensorFlow GraphDef | `pb` | yolo11n.pb
+ TensorFlow Lite | `tflite` | yolo11n.tflite
+ TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
+ TensorFlow.js | `tfjs` | yolo11n_web_model/
+ PaddlePaddle | `paddle` | yolo11n_paddle_model/
+ MNN | `mnn` | yolo11n.mnn
+ NCNN | `ncnn` | yolo11n_ncnn_model/
+ IMX | `imx` | yolo11n_imx_model/
+ RKNN | `rknn` | yolo11n_rknn_model/
+
+ Requirements:
+     $ pip install "ultralytics[export]"
+
+ Python:
+     from ultralytics import YOLO
+     model = YOLO('yolo11n.pt')
+     results = model.export(format='onnx')
+
+ CLI:
+     $ yolo mode=export model=yolo11n.pt format=onnx
+
+ Inference:
+     $ yolo predict model=yolo11n.pt             # PyTorch
+                          yolo11n.torchscript    # TorchScript
+                          yolo11n.onnx           # ONNX Runtime or OpenCV DNN with dnn=True
+                          yolo11n_openvino_model # OpenVINO
+                          yolo11n.engine         # TensorRT
+                          yolo11n.mlpackage      # CoreML (macOS-only)
+                          yolo11n_saved_model    # TensorFlow SavedModel
+                          yolo11n.pb             # TensorFlow GraphDef
+                          yolo11n.tflite         # TensorFlow Lite
+                          yolo11n_edgetpu.tflite # TensorFlow Edge TPU
+                          yolo11n_paddle_model   # PaddlePaddle
+                          yolo11n.mnn            # MNN
+                          yolo11n_ncnn_model     # NCNN
+                          yolo11n_imx_model      # IMX
+
+ TensorFlow.js:
+     $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
+     $ npm install
+     $ ln -s ../../yolo11n_web_model public/yolo11n_web_model
+     $ npm start
+ """
+
+ import json
+ import os
+ import re
+ import shutil
+ import subprocess
+ import time
+ import warnings
+ from contextlib import contextmanager
+ from copy import deepcopy
+ from datetime import datetime
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+
+ from ultralytics import __version__
+ from ultralytics.cfg import TASK2DATA, get_cfg
+ from ultralytics.data import build_dataloader
+ from ultralytics.data.dataset import YOLODataset
+ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+ from ultralytics.nn.autobackend import check_class_names, default_class_names
+ from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
+ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, WorldModel
+ from ultralytics.utils import (
+     ARM64,
+     DEFAULT_CFG,
+     IS_COLAB,
+     IS_JETSON,
+     LINUX,
+     LOGGER,
+     MACOS,
+     MACOS_VERSION,
+     RKNN_CHIPS,
+     ROOT,
+     SETTINGS,
+     WINDOWS,
+     YAML,
+     callbacks,
+     colorstr,
+     get_default_args,
+ )
+ from ultralytics.utils.checks import (
+     check_imgsz,
+     check_is_path_safe,
+     check_requirements,
+     check_version,
+     is_sudo_available,
+ )
+ from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
+ from ultralytics.utils.export import export_engine, export_onnx
+ from ultralytics.utils.files import file_size, spaces_in_path
+ from ultralytics.utils.ops import Profile, nms_rotated
+ from ultralytics.utils.torch_utils import TORCH_1_13, get_cpu_info, get_latest_opset, select_device
+
+
+ def export_formats():
+     """Return a dictionary of Ultralytics YOLO export formats."""
+     x = [
+         ["PyTorch", "-", ".pt", True, True, []],
+         ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "half", "nms"]],
+         ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]],
+         [
+             "OpenVINO",
+             "openvino",
+             "_openvino_model",
+             True,
+             False,
+             ["batch", "dynamic", "half", "int8", "nms", "fraction"],
+         ],
+         [
+             "TensorRT",
+             "engine",
+             ".engine",
+             False,
+             True,
+             ["batch", "dynamic", "half", "int8", "simplify", "nms", "fraction"],
+         ],
+         ["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]],
+         ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]],
+         ["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
+         ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms", "fraction"]],
+         ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
+         ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]],
+         ["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
+         ["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
+         ["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
+         ["IMX", "imx", "_imx_model", True, True, ["int8", "fraction"]],
+         ["RKNN", "rknn", "_rknn_model", False, False, ["batch", "name"]],
+     ]
+     return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x)))
+
+
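The returned dictionary is column-oriented: each key ('Format', 'Argument', 'Suffix', 'CPU', 'GPU', 'Arguments') maps to a tuple with one entry per format, so per-format rows are recovered by zipping the columns. A minimal sketch of listing each format with its supported export arguments:

    from ultralytics.engine.exporter import export_formats

    for name, arg, suffix, cpu, gpu, arguments in zip(*export_formats().values()):
        print(f"{name:<22} format={arg:<12} suffix={suffix:<18} args={arguments}")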
+ def validate_args(format, passed_args, valid_args):
+     """
+     Validate arguments based on the export format.
+
+     Args:
+         format (str): The export format.
+         passed_args (Namespace): The arguments used during export.
+         valid_args (list): List of valid arguments for the format.
+
+     Raises:
+         AssertionError: If an unsupported argument is used, or if the format lacks supported argument listings.
+     """
+     export_args = ["half", "int8", "dynamic", "keras", "nms", "batch", "fraction"]
+
+     assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed."
+     custom = {"batch": 1, "data": None, "device": None}  # exporter defaults
+     default_args = get_cfg(DEFAULT_CFG, custom)
+     for arg in export_args:
+         not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None)
+         if not_default:
+             assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'"
+
+
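For illustration, a hedged sketch of this check rejecting an unsupported flag; SimpleNamespace stands in for the parsed config, and 'pb' lists only 'batch' as a valid argument:

    from types import SimpleNamespace

    args = SimpleNamespace(half=True, int8=False, dynamic=False, keras=False, nms=False, batch=1, fraction=1.0)
    validate_args("pb", args, ["batch"])  # AssertionError: argument 'half' is not supported for format='pb'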
+ def gd_outputs(gd):
+     """Return TensorFlow GraphDef model output node names."""
+     name_list, input_list = [], []
+     for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
+         name_list.append(node.name)
+         input_list.extend(node.input)
+     return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
+
+
+ def try_export(inner_func):
+     """YOLO export decorator, i.e. @try_export."""
+     inner_args = get_default_args(inner_func)
+
+     def outer_func(*args, **kwargs):
+         """Export a model."""
+         prefix = inner_args["prefix"]
+         dt = 0.0
+         try:
+             with Profile() as dt:
+                 f, model = inner_func(*args, **kwargs)
+             LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
+             return f, model
+         except Exception as e:
+             LOGGER.error(f"{prefix} export failure {dt.t:.1f}s: {e}")
+             raise e
+
+     return outer_func
+
+
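The decorator assumes a contract: the wrapped function exposes a 'prefix' keyword default (read via get_default_args) and returns an (output_path, model) tuple. A hypothetical minimal example:

    @try_export
    def export_dummy(prefix=colorstr("Dummy:")):
        """Write a placeholder file and return (path, model) as the decorator expects."""
        f = Path("model.dummy")
        f.write_text("placeholder")
        return f, None

    export_dummy()  # logs "Dummy: export success ..." with elapsed time and file size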
+ @contextmanager
+ def arange_patch(args):
+     """
+     Workaround for ONNX torch.arange incompatibility with FP16.
+
+     https://github.com/pytorch/pytorch/issues/148041.
+     """
+     if args.dynamic and args.half and args.format == "onnx":
+         func = torch.arange
+
+         def arange(*args, dtype=None, **kwargs):
+             """Create the tensor in the default dtype, then cast, so FP16 traces export cleanly."""
+             return func(*args, **kwargs).to(dtype)  # cast to dtype instead of passing dtype
+
+         torch.arange = arange  # patch
+         yield
+         torch.arange = func  # unpatch
+     else:
+         yield
+
+
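A hedged sketch of the patch in action, with SimpleNamespace standing in for the parsed export args:

    from types import SimpleNamespace

    opts = SimpleNamespace(dynamic=True, half=True, format="onnx")
    with arange_patch(opts):
        t = torch.arange(5, dtype=torch.float16)  # built in the default dtype, then cast to fp16
    assert t.dtype == torch.float16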
+ class Exporter:
+     """
+     A class for exporting a model.
+
+     Attributes:
+         args (SimpleNamespace): Configuration for the exporter.
+         callbacks (list, optional): List of callback functions.
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """
+         Initialize the Exporter class.
+
+         Args:
+             cfg (str, optional): Path to a configuration file.
+             overrides (dict, optional): Configuration overrides.
+             _callbacks (dict, optional): Dictionary of callback functions.
+         """
+         self.args = get_cfg(cfg, overrides)
+         self.callbacks = _callbacks or callbacks.get_default_callbacks()
+         callbacks.add_integration_callbacks(self)
+
+     def __call__(self, model=None) -> str:
+         """Return the exported file or directory path after running export callbacks."""
+         self.run_callbacks("on_export_start")
+         t = time.time()
+         fmt = self.args.format.lower()  # to lowercase
+         if fmt in {"tensorrt", "trt"}:  # 'engine' aliases
+             fmt = "engine"
+         if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}:  # 'coreml' aliases
+             fmt = "coreml"
+         fmts_dict = export_formats()
+         fmts = tuple(fmts_dict["Argument"][1:])  # available export formats
+         if fmt not in fmts:
+             import difflib
+
+             # Get the closest match if format is invalid
+             matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6)  # 60% similarity required to match
+             if not matches:
+                 raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
+             LOGGER.warning(f"Invalid export format='{fmt}', updating to format='{matches[0]}'")
+             fmt = matches[0]
+         flags = [x == fmt for x in fmts]
+         if sum(flags) != 1:
+             raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
+         (jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn, imx, rknn) = (
+             flags  # export booleans
+         )
+
+         is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
+
+         # Device
+         dla = None
+         if fmt == "engine" and self.args.device is None:
+             LOGGER.warning("TensorRT requires GPU export, automatically assigning device=0")
+             self.args.device = "0"
+         if fmt == "engine" and "dla" in str(self.args.device):  # convert int/list to str first
+             dla = self.args.device.split(":")[-1]
+             self.args.device = "0"  # update device to "0"
+             assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1', but got {self.args.device}."
+         if imx and self.args.device is None and torch.cuda.is_available():
+             LOGGER.warning("Exporting on CPU while CUDA is available, setting device=0 for faster export on GPU.")
+             self.args.device = "0"  # update device to "0"
+         self.device = select_device("cpu" if self.args.device is None else self.args.device)
+
+         # Argument compatibility checks
+         fmt_keys = fmts_dict["Arguments"][flags.index(True) + 1]
+         validate_args(fmt, self.args, fmt_keys)
+         if imx:
+             if not self.args.int8:
+                 LOGGER.warning("IMX export requires int8=True, setting int8=True.")
+                 self.args.int8 = True
+             if model.task != "detect":
+                 raise ValueError("IMX export only supported for detection models.")
+             if not hasattr(model, "names"):
+                 model.names = default_class_names()
+             model.names = check_class_names(model.names)
+         if self.args.half and self.args.int8:
+             LOGGER.warning("half=True and int8=True are mutually exclusive, setting half=False.")
+             self.args.half = False
+         if self.args.half and onnx and self.device.type == "cpu":
+             LOGGER.warning("half=True only compatible with GPU export, i.e. use device=0")
+             self.args.half = False
+         self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
+         if self.args.int8 and engine:
+             self.args.dynamic = True  # enforce dynamic to export TensorRT INT8
+         if self.args.optimize:
+             assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
+             assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
+         if rknn:
+             if not self.args.name:
+                 LOGGER.warning(
+                     "Rockchip RKNN export requires a 'name' arg to select the processor type, but none was "
+                     "provided. Using default name='rk3588'."
+                 )
+                 self.args.name = "rk3588"
+             self.args.name = self.args.name.lower()
+             assert self.args.name in RKNN_CHIPS, (
+                 f"Invalid processor name '{self.args.name}' for Rockchip RKNN export. Valid names are {RKNN_CHIPS}."
+             )
+         if self.args.int8 and tflite:
+             assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
+         if self.args.nms:
+             assert not isinstance(model, ClassificationModel), "'nms=True' is not valid for classification models."
+             assert not (tflite and ARM64 and LINUX), "TFLite export with NMS unsupported on ARM64 Linux"
+             if getattr(model, "end2end", False):
+                 LOGGER.warning("'nms=True' is not available for end2end models. Forcing 'nms=False'.")
+                 self.args.nms = False
+             self.args.conf = self.args.conf or 0.25  # set conf default value for nms export
+         if edgetpu:
+             if not LINUX or ARM64:
+                 raise SystemError(
+                     "Edge TPU export only supported on non-aarch64 Linux. See https://coral.ai/docs/edgetpu/compiler"
+                 )
+             elif self.args.batch != 1:  # see github.com/ultralytics/ultralytics/pull/13420
+                 LOGGER.warning("Edge TPU export requires batch size 1, setting batch=1.")
+                 self.args.batch = 1
+         if isinstance(model, WorldModel):
+             LOGGER.warning(
+                 "YOLOWorld (original version) export is not supported to any format. "
+                 "YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to "
+                 "(torchscript, onnx, openvino, engine, coreml) formats. "
+                 "See https://docs.ultralytics.com/models/yolo-world for details."
+             )
+             model.clip_model = None  # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
+         if self.args.int8 and not self.args.data:
+             self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")]  # assign default data
+             LOGGER.warning(
+                 f"INT8 export requires a 'data' arg for calibration, but none was provided. "
+                 f"Using default 'data={self.args.data}'."
+             )
+         if tfjs and (ARM64 and LINUX):
+             raise SystemError("TF.js exports are not currently supported on ARM64 Linux")
+         # Recommend OpenVINO export for Intel CPUs
+         if SETTINGS.get("openvino_msg"):
+             if "intel" in get_cpu_info().lower():
+                 LOGGER.info(
+                     "💡 ProTip: Export to OpenVINO format for best performance on Intel CPUs."
+                     " Learn more at https://docs.ultralytics.com/integrations/openvino/"
+                 )
+             SETTINGS["openvino_msg"] = False
+
+         # Input
+         im = torch.zeros(self.args.batch, model.yaml.get("channels", 3), *self.imgsz).to(self.device)
+         file = Path(
+             getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
+         )
+         if file.suffix in {".yaml", ".yml"}:
+             file = Path(file.name)
+
+         # Update model
+         model = deepcopy(model).to(self.device)
+         for p in model.parameters():
+             p.requires_grad = False
+         model.eval()
+         model.float()
+         model = model.fuse()
+
+         if imx:
+             from ultralytics.utils.torch_utils import FXModel
+
+             model = FXModel(model)
+         for m in model.modules():
+             if isinstance(m, Classify):
+                 m.export = True
+             if isinstance(m, (Detect, RTDETRDecoder)):  # includes all Detect subclasses like Segment, Pose, OBB
+                 m.dynamic = self.args.dynamic
+                 m.export = True
+                 m.format = self.args.format
+                 m.max_det = self.args.max_det
+                 m.xyxy = self.args.nms and not coreml
+             elif isinstance(m, C2f) and not is_tf_format:
+                 # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
+                 m.forward = m.forward_split
+             if isinstance(m, Detect) and imx:
+                 from ultralytics.utils.tal import make_anchors
+
+                 m.anchors, m.strides = (
+                     x.transpose(0, 1)
+                     for x in make_anchors(
+                         torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
+                     )
+                 )
+
+         y = None
+         for _ in range(2):  # dry runs
+             y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
+         if self.args.half and onnx and self.device.type != "cpu":
+             im, model = im.half(), model.half()  # to FP16
+
+         # Filter warnings
+         warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)  # suppress TracerWarning
+         warnings.filterwarnings("ignore", category=UserWarning)  # suppress shape prim::Constant missing ONNX warning
+         warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress CoreML np.bool deprecation warning
+
+         # Assign
+         self.im = im
+         self.model = model
+         self.file = file
+         self.output_shape = (
+             tuple(y.shape)
+             if isinstance(y, torch.Tensor)
+             else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
+         )
+         self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
+         data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
+         description = f"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}"
+         self.metadata = {
+             "description": description,
+             "author": "Ultralytics",
+             "date": datetime.now().isoformat(),
+             "version": __version__,
+             "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+             "docs": "https://docs.ultralytics.com",
+             "stride": int(max(model.stride)),
+             "task": model.task,
+             "batch": self.args.batch,
+             "imgsz": self.imgsz,
+             "names": model.names,
+             "args": {k: v for k, v in self.args if k in fmt_keys},
+             "channels": model.yaml.get("channels", 3),
+         }  # model metadata
+         if dla is not None:
+             self.metadata["dla"] = dla  # make sure `AutoBackend` uses correct dla device if it has one
+         if model.task == "pose":
+             self.metadata["kpt_shape"] = model.model[-1].kpt_shape
+
+         LOGGER.info(
+             f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
+             f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)"
+         )
+
+         # Exports
+         f = [""] * len(fmts)  # exported filenames
+         if jit or ncnn:  # TorchScript
+             f[0], _ = self.export_torchscript()
+         if engine:  # TensorRT required before ONNX
+             f[1], _ = self.export_engine(dla=dla)
+         if onnx:  # ONNX
+             f[2], _ = self.export_onnx()
+         if xml:  # OpenVINO
+             f[3], _ = self.export_openvino()
+         if coreml:  # CoreML
+             f[4], _ = self.export_coreml()
+         if is_tf_format:  # TensorFlow formats
+             self.args.int8 |= edgetpu
+             f[5], keras_model = self.export_saved_model()
+             if pb or tfjs:  # pb prerequisite to tfjs
+                 f[6], _ = self.export_pb(keras_model=keras_model)
+             if tflite:
+                 f[7], _ = self.export_tflite()
+             if edgetpu:
+                 f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
+             if tfjs:
+                 f[9], _ = self.export_tfjs()
+         if paddle:  # PaddlePaddle
+             f[10], _ = self.export_paddle()
+         if mnn:  # MNN
+             f[11], _ = self.export_mnn()
+         if ncnn:  # NCNN
+             f[12], _ = self.export_ncnn()
+         if imx:
+             f[13], _ = self.export_imx()
+         if rknn:
+             f[14], _ = self.export_rknn()
+
+         # Finish
+         f = [str(x) for x in f if x]  # filter out '' and None
+         if any(f):
+             f = str(Path(f[-1]))
+             square = self.imgsz[0] == self.imgsz[1]
+             s = (
+                 ""
+                 if square
+                 else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
+                 f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
+             )
+             imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
+             predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
+             q = "int8" if self.args.int8 else "half" if self.args.half else ""  # quantization
+             LOGGER.info(
+                 f"\nExport complete ({time.time() - t:.1f}s)"
+                 f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+                 f"\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}"
+                 f"\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}"
+                 f"\nVisualize: https://netron.app"
+             )
+
+         self.run_callbacks("on_export_end")
+         return f  # return path of the last exported file/dir
+
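A hedged sketch of driving this class directly; the supported entry point is YOLO.export(), which builds an Exporter from the same overrides internally:

    from ultralytics import YOLO

    yolo = YOLO("yolo11n.pt")
    exporter = Exporter(overrides={"format": "onnx", "imgsz": 640})
    path = exporter(model=yolo.model)  # e.g. 'yolo11n.onnx'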
+     def get_int8_calibration_dataloader(self, prefix=""):
+         """Build and return a dataloader for calibration of INT8 models."""
+         LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
+         data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)
+         # TensorRT INT8 calibration should use 2x batch size
+         batch = self.args.batch * (2 if self.args.format == "engine" else 1)
+         dataset = YOLODataset(
+             data[self.args.split or "val"],
+             data=data,
+             fraction=self.args.fraction,
+             task=self.model.task,
+             imgsz=self.imgsz[0],
+             augment=False,
+             batch_size=batch,
+         )
+         n = len(dataset)
+         if n < self.args.batch:
+             raise ValueError(
+                 f"The calibration dataset ({n} images) must have at least as many images as the batch size "
+                 f"('batch={self.args.batch}')."
+             )
+         elif n < 300:
+             LOGGER.warning(f"{prefix} >300 images recommended for INT8 calibration, found {n} images.")
+         return build_dataloader(dataset, batch=batch, workers=0)  # required for batch loading
+
+     @try_export
+     def export_torchscript(self, prefix=colorstr("TorchScript:")):
+         """YOLO TorchScript model export."""
+         LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
+         f = self.file.with_suffix(".torchscript")
+
+         ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
+         extra_files = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
+         if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
+             LOGGER.info(f"{prefix} optimizing for mobile...")
+             from torch.utils.mobile_optimizer import optimize_for_mobile
+
+             optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
+         else:
+             ts.save(str(f), _extra_files=extra_files)
+         return f, None
+
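The 'config.txt' metadata embedded above can be recovered at load time through standard torch.jit APIs; a minimal sketch:

    import json

    import torch

    extra_files = {"config.txt": ""}  # filled in place by torch.jit.load
    ts = torch.jit.load("yolo11n.torchscript", _extra_files=extra_files)
    metadata = json.loads(extra_files["config.txt"])  # stride, imgsz, names, ...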
+     @try_export
+     def export_onnx(self, prefix=colorstr("ONNX:")):
+         """YOLO ONNX export."""
+         requirements = ["onnx>=1.12.0,<1.18.0"]
+         if self.args.simplify:
+             requirements += ["onnxslim>=0.1.46", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
+         check_requirements(requirements)
+         import onnx  # noqa
+
+         opset_version = self.args.opset or get_latest_opset()
+         LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
+         f = str(self.file.with_suffix(".onnx"))
+         output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
+         dynamic = self.args.dynamic
+         if dynamic:
+             dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
+             if isinstance(self.model, SegmentationModel):
+                 dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
+                 dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
+             elif isinstance(self.model, DetectionModel):
+                 dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
+             if self.args.nms:  # only batch size is dynamic with NMS
+                 dynamic["output0"].pop(2)
+         if self.args.nms and self.model.task == "obb":
+             self.args.opset = opset_version  # for NMSModel
+
+         with arange_patch(self.args):
+             export_onnx(
+                 NMSModel(self.model, self.args) if self.args.nms else self.model,
+                 self.im,
+                 f,
+                 opset=opset_version,
+                 input_names=["images"],
+                 output_names=output_names,
+                 dynamic=dynamic or None,
+             )
+
+         # Checks
+         model_onnx = onnx.load(f)  # load onnx model
+
+         # Simplify
+         if self.args.simplify:
+             try:
+                 import onnxslim
+
+                 LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...")
+                 model_onnx = onnxslim.slim(model_onnx)
+
+             except Exception as e:
+                 LOGGER.warning(f"{prefix} simplifier failure: {e}")
+
+         # Metadata
+         for k, v in self.metadata.items():
+             meta = model_onnx.metadata_props.add()
+             meta.key, meta.value = k, str(v)
+
+         onnx.save(model_onnx, f)
+         return f, model_onnx
+
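A minimal inference sketch against the exported file, assuming onnxruntime is installed; the input name 'images' matches input_names above:

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("yolo11n.onnx")
    im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # BCHW, 0.0-1.0 normalized in real use
    outputs = session.run(None, {"images": im})  # e.g. one (1, 84, 8400) array for detection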
+     @try_export
+     def export_openvino(self, prefix=colorstr("OpenVINO:")):
+         """YOLO OpenVINO export."""
+         if MACOS:
+             msg = "OpenVINO error in macOS>=15.4 https://github.com/openvinotoolkit/openvino/issues/30023"
+             check_version(MACOS_VERSION, "<15.4", name="macOS ", hard=True, msg=msg)
+         check_requirements("openvino>=2024.0.0")
+         import openvino as ov
+
+         LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
+         assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
+         ov_model = ov.convert_model(
+             NMSModel(self.model, self.args) if self.args.nms else self.model,
+             input=None if self.args.dynamic else [self.im.shape],
+             example_input=self.im,
+         )
+
+         def serialize(ov_model, file):
+             """Set RT info, serialize, and save metadata YAML."""
+             ov_model.set_rt_info("YOLO", ["model_info", "model_type"])
+             ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
+             ov_model.set_rt_info(114, ["model_info", "pad_value"])
+             ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
+             ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
+             ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
+             if self.model.task != "classify":
+                 ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
+
+             ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
+             YAML.save(Path(file).parent / "metadata.yaml", self.metadata)  # add metadata.yaml
+
+         if self.args.int8:
+             fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
+             fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
+             # INT8 requires nncf, nncf requires packaging>=23.2 https://github.com/openvinotoolkit/nncf/issues/3463
+             check_requirements("packaging>=23.2")  # must be installed first to build nncf wheel
+             check_requirements("nncf>=2.14.0")
+             import nncf
+
+             def transform_fn(data_item) -> np.ndarray:
+                 """Quantization transform function."""
+                 data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
+                 assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
+                 im = data_item.numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0-255 to 0.0-1.0
+                 return np.expand_dims(im, 0) if im.ndim == 3 else im
+
+             # Generate calibration data for integer quantization
+             ignored_scope = None
+             if isinstance(self.model.model[-1], Detect):
+                 # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect
+                 head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
+                 ignored_scope = nncf.IgnoredScope(  # ignore operations
+                     patterns=[
+                         f".*{head_module_name}/.*/Add",
+                         f".*{head_module_name}/.*/Sub*",
+                         f".*{head_module_name}/.*/Mul*",
+                         f".*{head_module_name}/.*/Div*",
+                         f".*{head_module_name}\\.dfl.*",
+                     ],
+                     types=["Sigmoid"],
+                 )
+
+             quantized_ov_model = nncf.quantize(
+                 model=ov_model,
+                 calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn),
+                 preset=nncf.QuantizationPreset.MIXED,
+                 ignored_scope=ignored_scope,
+             )
+             serialize(quantized_ov_model, fq_ov)
+             return fq, None
+
+         f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
+         f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
+
+         serialize(ov_model, f_ov)
+         return f, None
+
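A minimal sketch of loading the resulting IR with the openvino runtime; the paths assume a 'yolo11n' export:

    import numpy as np
    import openvino as ov

    core = ov.Core()
    compiled = core.compile_model("yolo11n_openvino_model/yolo11n.xml", "CPU")
    result = compiled(np.zeros((1, 3, 640, 640), dtype=np.float32))  # dict keyed by output tensors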
+     @try_export
+     def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
+         """YOLO Paddle export."""
+         assert not IS_JETSON, "Jetson Paddle exports not supported yet"
+         check_requirements(("paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle>=3.0.0", "x2paddle"))
+         import x2paddle  # noqa
+         from x2paddle.convert import pytorch2paddle  # noqa
+
+         LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
+         f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")
+
+         pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im])  # export
+         YAML.save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
+         return f, None
+
+     @try_export
+     def export_mnn(self, prefix=colorstr("MNN:")):
+         """YOLO MNN export using MNN https://github.com/alibaba/MNN."""
+         f_onnx, _ = self.export_onnx()  # get onnx model first
+
+         check_requirements("MNN>=2.9.6")
+         import MNN  # noqa
+         from MNN.tools import mnnconvert
+
+         # Setup and checks
+         LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...")
+         assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
+         f = str(self.file.with_suffix(".mnn"))  # MNN model file
+         args = ["", "-f", "ONNX", "--modelFile", f_onnx, "--MNNModel", f, "--bizCode", json.dumps(self.metadata)]
+         if self.args.int8:
+             args.extend(("--weightQuantBits", "8"))
+         if self.args.half:
+             args.append("--fp16")
+         mnnconvert.convert(args)
+         # Remove the scratch file created during model conversion optimization
+         convert_scratch = Path(self.file.parent / ".__convert_external_data.bin")
+         if convert_scratch.exists():
+             convert_scratch.unlink()
+         return f, None
+
+     @try_export
+     def export_ncnn(self, prefix=colorstr("NCNN:")):
+         """YOLO NCNN export using PNNX https://github.com/pnnx/pnnx."""
+         check_requirements("ncnn")
+         import ncnn  # noqa
+
+         LOGGER.info(f"\n{prefix} starting export with NCNN {ncnn.__version__}...")
+         f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
+         f_ts = self.file.with_suffix(".torchscript")
+
+         name = Path("pnnx.exe" if WINDOWS else "pnnx")  # PNNX filename
+         pnnx = name if name.is_file() else (ROOT / name)
+         if not pnnx.is_file():
+             LOGGER.warning(
+                 f"{prefix} PNNX not found. Attempting to download binary file from "
+                 "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
+                 f"or in {ROOT}. See PNNX repo for full installation instructions."
+             )
+             system = "macos" if MACOS else "windows" if WINDOWS else "linux-aarch64" if ARM64 else "linux"
+             try:
+                 release, assets = get_github_assets(repo="pnnx/pnnx")
+                 asset = [x for x in assets if f"{system}.zip" in x][0]
+                 assert isinstance(asset, str), "Unable to retrieve PNNX repo assets"  # i.e. pnnx-20240410-macos.zip
+                 LOGGER.info(f"{prefix} successfully found latest PNNX asset file {asset}")
+             except Exception as e:
+                 release = "20240410"
+                 asset = f"pnnx-{release}-{system}.zip"
+                 LOGGER.warning(f"{prefix} PNNX GitHub assets not found: {e}, using default {asset}")
+             unzip_dir = safe_download(f"https://github.com/pnnx/pnnx/releases/download/{release}/{asset}", delete=True)
+             if check_is_path_safe(Path.cwd(), unzip_dir):  # avoid path traversal security vulnerability
+                 shutil.move(src=unzip_dir / name, dst=pnnx)  # move binary to ROOT
+                 pnnx.chmod(0o777)  # set read, write, and execute permissions for everyone
+                 shutil.rmtree(unzip_dir)  # delete unzip dir
+
+         ncnn_args = [
+             f"ncnnparam={f / 'model.ncnn.param'}",
+             f"ncnnbin={f / 'model.ncnn.bin'}",
+             f"ncnnpy={f / 'model_ncnn.py'}",
+         ]
+
+         pnnx_args = [
+             f"pnnxparam={f / 'model.pnnx.param'}",
+             f"pnnxbin={f / 'model.pnnx.bin'}",
+             f"pnnxpy={f / 'model_pnnx.py'}",
+             f"pnnxonnx={f / 'model.pnnx.onnx'}",
+         ]
+
+         cmd = [
+             str(pnnx),
+             str(f_ts),
+             *ncnn_args,
+             *pnnx_args,
+             f"fp16={int(self.args.half)}",
+             f"device={self.device.type}",
+             f'inputshape="{[self.args.batch, 3, *self.imgsz]}"',
+         ]
+         f.mkdir(exist_ok=True)  # make ncnn_model directory
+         LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
+         subprocess.run(cmd, check=True)
+
+         # Remove debug files
+         pnnx_files = [x.split("=")[-1] for x in pnnx_args]
+         for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
+             Path(f_debug).unlink(missing_ok=True)
+
+         YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
+         return str(f), None
+
+     @try_export
+     def export_coreml(self, prefix=colorstr("CoreML:")):
+         """YOLO CoreML export."""
+         mlmodel = self.args.format.lower() == "mlmodel"  # legacy *.mlmodel export format requested
+         check_requirements("coremltools>=8.0")
+         import coremltools as ct  # noqa
+
+         LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
+         assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux."
+         assert self.args.batch == 1, "CoreML batch sizes > 1 are not supported. Please retry at 'batch=1'."
+         f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
+         if f.is_dir():
+             shutil.rmtree(f)
+
+         bias = [0.0, 0.0, 0.0]
+         scale = 1 / 255
+         classifier_config = None
+         if self.model.task == "classify":
+             classifier_config = ct.ClassifierConfig(list(self.model.names.values()))
+             model = self.model
+         elif self.model.task == "detect":
+             model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
+         else:
+             if self.args.nms:
+                 LOGGER.warning(f"{prefix} 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
+             # TODO CoreML Segment and Pose model pipelining
+             model = self.model
+         ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
+
+         # Based on Apple's documentation it is better to leave out minimum_deployment_target and let it be set
+         # internally based on the model conversion and output type.
+         # Setting minimum_deployment_target >= iOS16 will require setting compute_precision=ct.precision.FLOAT32.
+         # iOS16 adds in better support for FP16, but none of the CoreML NMS specifications handle FP16 as input.
+         ct_model = ct.convert(
+             ts,
+             inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],  # expects ct.TensorType
+             classifier_config=classifier_config,
+             convert_to="neuralnetwork" if mlmodel else "mlprogram",
+         )
+         bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
+         if bits < 32:
+             if "kmeans" in mode:
+                 check_requirements("scikit-learn")  # scikit-learn package required for k-means quantization
+             if mlmodel:
+                 ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
+             elif bits == 8:  # mlprogram already quantized to FP16
+                 import coremltools.optimize.coreml as cto
+
+                 op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
+                 config = cto.OptimizationConfig(global_config=op_config)
+                 ct_model = cto.palettize_weights(ct_model, config=config)
+         if self.args.nms and self.model.task == "detect":
+             if mlmodel:
+                 weights_dir = None
+             else:
+                 ct_model.save(str(f))  # save otherwise weights_dir does not exist
+                 weights_dir = str(f / "Data/com.apple.CoreML/weights")
+             ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)
+
+         m = self.metadata  # metadata dict
+         ct_model.short_description = m.pop("description")
+         ct_model.author = m.pop("author")
+         ct_model.license = m.pop("license")
+         ct_model.version = m.pop("version")
+         ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
+         if self.model.task == "classify":
+             ct_model.user_defined_metadata.update({"com.apple.coreml.model.preview.type": "imageClassifier"})
+
+         try:
+             ct_model.save(str(f))  # save *.mlpackage
+         except Exception as e:
+             LOGGER.warning(
+                 f"{prefix} CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
+                 f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
+             )
+             f = f.with_suffix(".mlmodel")
+             ct_model.save(str(f))
+         return f, ct_model
+
+     @try_export
+     def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
+         """YOLO TensorRT export https://developer.nvidia.com/tensorrt."""
+         assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
+         f_onnx, _ = self.export_onnx()  # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
+
+         try:
+             import tensorrt as trt  # noqa
+         except ImportError:
+             if LINUX:
+                 check_requirements("tensorrt>7.0.0,!=10.1.0")
+             import tensorrt as trt  # noqa
+         check_version(trt.__version__, ">=7.0.0", hard=True)
+         check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
+
+         # Setup and checks
+         LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
+         assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
+         f = self.file.with_suffix(".engine")  # TensorRT engine file
+         export_engine(
+             f_onnx,
+             f,
+             self.args.workspace,
+             self.args.half,
+             self.args.int8,
+             self.args.dynamic,
+             self.im.shape,
+             dla=dla,
+             dataset=self.get_int8_calibration_dataloader(prefix) if self.args.int8 else None,
+             metadata=self.metadata,
+             verbose=self.args.verbose,
+             prefix=prefix,
+         )
+
+         return f, None
+
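TensorRT engines are consumed back through the same high-level API; a hedged sketch (requires an NVIDIA GPU and the tensorrt wheel):

    from ultralytics import YOLO

    path = YOLO("yolo11n.pt").export(format="engine", half=True, device=0)  # creates 'yolo11n.engine'
    YOLO(path).predict("https://ultralytics.com/images/bus.jpg")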
+     @try_export
+     def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
+         """YOLO TensorFlow SavedModel export."""
+         cuda = torch.cuda.is_available()
+         try:
+             import tensorflow as tf  # noqa
+         except ImportError:
+             check_requirements("tensorflow>=2.0.0")
+             import tensorflow as tf  # noqa
+         check_requirements(
+             (
+                 "tf_keras",  # required by 'onnx2tf' package
+                 "sng4onnx>=1.0.1",  # required by 'onnx2tf' package
+                 "onnx_graphsurgeon>=0.3.26",  # required by 'onnx2tf' package
+                 "ai-edge-litert>=1.2.0",  # required by 'onnx2tf' package
+                 "onnx>=1.12.0",
+                 "onnx2tf>=1.26.3",
+                 "onnxslim>=0.1.46",
+                 "onnxruntime-gpu" if cuda else "onnxruntime",
+                 "protobuf>=5",
+             ),
+             cmds="--extra-index-url https://pypi.ngc.nvidia.com",  # onnx_graphsurgeon only on NVIDIA
+         )
+
+         LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+         check_version(
+             tf.__version__,
+             ">=2.0.0",
+             name="tensorflow",
+             verbose=True,
+             msg="https://github.com/ultralytics/ultralytics/issues/5161",
+         )
+         import onnx2tf
+
+         f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
+         if f.is_dir():
+             shutil.rmtree(f)  # delete output folder
+
+         # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
+         onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
+         if not onnx2tf_file.exists():
+             attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
+
+         # Export to ONNX
+         self.args.simplify = True
+         f_onnx, _ = self.export_onnx()
+
+         # Export to TF
+         np_data = None
+         if self.args.int8:
+             tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
+             if self.args.data:
+                 f.mkdir()
+                 images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
+                 images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
+                     0, 2, 3, 1
+                 )
+                 np.save(str(tmp_file), images.numpy().astype(np.float32))  # BHWC
+                 np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
+
+         LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
+         keras_model = onnx2tf.convert(
+             input_onnx_file_path=f_onnx,
+             output_folder_path=str(f),
+             not_use_onnxsim=True,
+             verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
+             output_integer_quantized_tflite=self.args.int8,
+             quant_type="per-tensor",  # "per-tensor" (faster) or "per-channel" (slower but more accurate)
+             custom_input_op_name_np_data_path=np_data,
+             enable_batchmatmul_unfold=True,  # fix lower no. of detected objects on GPU delegate
+             output_signaturedefs=True,  # fix error with Attention block group convolution
+             optimization_for_gpu_delegate=True,
+         )
+         YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
+
+         # Remove/rename TFLite models
+         if self.args.int8:
+             tmp_file.unlink(missing_ok=True)
+             for file in f.rglob("*_dynamic_range_quant.tflite"):
+                 file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
+             for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
+                 file.unlink()  # delete extra fp16 activation TFLite files
+
+         # Add TFLite metadata
+         for file in f.rglob("*.tflite"):
+             file.unlink() if "quant_with_int16_act.tflite" in str(file) else self._add_tflite_metadata(file)
+
+         return str(f), keras_model  # or keras_model = tf.saved_model.load(f, tags=None, options=None)
+
1003
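+     # Usage sketch, assuming the standard Ultralytics entry point (weights and dataset
+     # names are illustrative):
+     #     from ultralytics import YOLO
+     #     YOLO("yolo11n.pt").export(format="saved_model", int8=True, data="coco8.yaml")
+ 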
+     @try_export
+     def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
+         """YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen-Graph-TensorFlow."""
+         import tensorflow as tf  # noqa
+         from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa
+ 
+         LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+         f = self.file.with_suffix(".pb")
+ 
+         m = tf.function(lambda x: keras_model(x))  # full model
+         m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
+         frozen_func = convert_variables_to_constants_v2(m)
+         frozen_func.graph.as_graph_def()
+         tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
+         return f, None
+ 
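+     # Reading the frozen graph back, a minimal sketch with the standard TensorFlow API
+     # (path is illustrative):
+     #     import tensorflow as tf
+     #     gd = tf.compat.v1.GraphDef()
+     #     with open("yolo11n.pb", "rb") as fh:
+     #         gd.ParseFromString(fh.read())
+ 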
+     @try_export
+     def export_tflite(self, prefix=colorstr("TensorFlow Lite:")):
+         """YOLO TensorFlow Lite export."""
+         # BUG https://github.com/ultralytics/ultralytics/issues/13436
+         import tensorflow as tf  # noqa
+ 
+         LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+         saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
+         if self.args.int8:
+             f = saved_model / f"{self.file.stem}_int8.tflite"  # fp32 in/out
+         elif self.args.half:
+             f = saved_model / f"{self.file.stem}_float16.tflite"  # fp32 in/out
+         else:
+             f = saved_model / f"{self.file.stem}_float32.tflite"
+         return str(f), None
+ 
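+     # Running the exported file, a minimal sketch with the standard tf.lite API (path illustrative):
+     #     import numpy as np, tensorflow as tf
+     #     interp = tf.lite.Interpreter(model_path="yolo11n_saved_model/yolo11n_float32.tflite")
+     #     interp.allocate_tensors()
+     #     inp = interp.get_input_details()[0]
+     #     interp.set_tensor(inp["index"], np.zeros(inp["shape"], dtype=np.float32))
+     #     interp.invoke()
+ 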
+     @try_export
+     def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
+         """YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
+         cmd = "edgetpu_compiler --version"
+         help_url = "https://coral.ai/docs/edgetpu/compiler/"
+         assert LINUX, f"export only supported on Linux. See {help_url}"
+         if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
+             LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
+             for c in (
+                 "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
+                 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
+                 "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
+                 "sudo apt-get update",
+                 "sudo apt-get install edgetpu-compiler",
+             ):
+                 subprocess.run(c if is_sudo_available() else c.replace("sudo ", ""), shell=True, check=True)
+         ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
+ 
+         LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
+         f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model
+ 
+         cmd = (
+             "edgetpu_compiler "
+             f'--out_dir "{Path(f).parent}" '
+             "--show_operations "
+             "--search_delegate "
+             "--delegate_search_step 30 "
+             "--timeout_sec 180 "
+             f'"{tflite_model}"'
+         )
+         LOGGER.info(f"{prefix} running '{cmd}'")
+         subprocess.run(cmd, shell=True)
+         self._add_tflite_metadata(f)
+         return f, None
+ 
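+     # Equivalent manual compile step, a sketch mirroring the command assembled above
+     # (input file name is illustrative):
+     #     edgetpu_compiler --out_dir . --search_delegate --delegate_search_step 30 yolo11n_full_integer_quant.tflite
+ 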
+     @try_export
+     def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
+         """YOLO TensorFlow.js export."""
+         check_requirements("tensorflowjs")
+         import tensorflow as tf
+         import tensorflowjs as tfjs  # noqa
+ 
+         LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
+         f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
+         f_pb = str(self.file.with_suffix(".pb"))  # *.pb path
+ 
+         gd = tf.Graph().as_graph_def()  # TF GraphDef
+         with open(f_pb, "rb") as file:
+             gd.ParseFromString(file.read())
+         outputs = ",".join(gd_outputs(gd))
+         LOGGER.info(f"\n{prefix} output node names: {outputs}")
+ 
+         quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
+         with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter cannot handle spaces in path
+             cmd = (
+                 "tensorflowjs_converter "
+                 f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
+             )
+             LOGGER.info(f"{prefix} running '{cmd}'")
+             subprocess.run(cmd, shell=True)
+ 
+         if " " in f:
+             LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{f}'.")
+ 
+         # Add metadata
+         YAML.save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
+         return f, None
+ 
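+     # Usage sketch, assuming the standard Ultralytics entry point (weights name illustrative):
+     #     from ultralytics import YOLO
+     #     YOLO("yolo11n.pt").export(format="tfjs", half=True)  # writes yolo11n_web_model/
+ 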
+     @try_export
+     def export_rknn(self, prefix=colorstr("RKNN:")):
+         """YOLO RKNN model export."""
+         LOGGER.info(f"\n{prefix} starting export with rknn-toolkit2...")
+ 
+         check_requirements("rknn-toolkit2")
+         if IS_COLAB:
+             # Prevent 'exit' from closing the notebook https://github.com/airockchip/rknn-toolkit2/issues/259
+             import builtins
+ 
+             builtins.exit = lambda: None
+ 
+         from rknn.api import RKNN
+ 
+         f, _ = self.export_onnx()
+         export_path = Path(f"{Path(f).stem}_rknn_model")
+         export_path.mkdir(exist_ok=True)
+ 
+         rknn = RKNN(verbose=False)
+         rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], target_platform=self.args.name)
+         rknn.load_onnx(model=f)
+         rknn.build(do_quantization=False)  # TODO: Add quantization support
+         f = f.replace(".onnx", f"-{self.args.name}.rknn")
+         rknn.export_rknn(f"{export_path / f}")
+         YAML.save(export_path / "metadata.yaml", self.metadata)
+         return export_path, None
+ 
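+     # Usage sketch, assuming the standard Ultralytics entry point; `name` selects the
+     # Rockchip target platform passed to rknn.config() above:
+     #     from ultralytics import YOLO
+     #     YOLO("yolo11n.pt").export(format="rknn", name="rk3588")
+ 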
+     @try_export
+     def export_imx(self, prefix=colorstr("IMX:")):
+         """YOLO IMX export."""
+         gptq = False
+         assert LINUX, (
+             "export only supported on Linux. "
+             "See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
+         )
+         if getattr(self.model, "end2end", False):
+             raise ValueError("IMX export is not supported for end2end models.")
+         check_requirements(("model-compression-toolkit>=2.3.0", "sony-custom-layers>=0.3.0", "edge-mdt-tpc>=1.1.0"))
+         check_requirements("imx500-converter[pt]>=3.16.1")  # Separate requirements for imx500-converter
+ 
+         import model_compression_toolkit as mct
+         import onnx
+         from edgemdt_tpc import get_target_platform_capabilities
+         from sony_custom_layers.pytorch import multiclass_nms
+ 
+         LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
+ 
+         # Install Java>=17
+         try:
+             java_output = subprocess.run(["java", "--version"], check=True, capture_output=True).stdout.decode()
+             version_match = re.search(r"(?:openjdk|java) (\d+)", java_output)
+             java_version = int(version_match.group(1)) if version_match else 0
+             assert java_version >= 17, "Java version too old"
+         except (FileNotFoundError, subprocess.CalledProcessError, AssertionError):
+             cmd = (["sudo"] if is_sudo_available() else []) + ["apt", "install", "-y", "openjdk-21-jre"]
+             subprocess.run(cmd, check=True)
+ 
+         def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
+             for batch in dataloader:
+                 img = batch["img"]
+                 img = img / 255.0
+                 yield [img]
+ 
+         tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")
+ 
+         bit_cfg = mct.core.BitWidthConfig()
+         if "C2PSA" in self.model.__str__():  # YOLO11
+             layer_names = ["sub", "mul_2", "add_14", "cat_21"]
+             weights_memory = 2585350.2439
+             n_layers = 238  # 238 layers for fused YOLO11n
+         else:  # YOLOv8
+             layer_names = ["sub", "mul", "add_6", "cat_17"]
+             weights_memory = 2550540.8
+             n_layers = 168  # 168 layers for fused YOLOv8n
+ 
+         # Check if the model has the expected number of layers
+         if len(list(self.model.modules())) != n_layers:
+             raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")
+ 
+         for layer_name in layer_names:
+             bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)
+ 
+         config = mct.core.CoreConfig(
+             mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
+             quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
+             bit_width_config=bit_cfg,
+         )
+ 
+         resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)
+ 
+         quant_model = (
+             mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
+                 model=self.model,
+                 representative_data_gen=representative_dataset_gen,
+                 target_resource_utilization=resource_utilization,
+                 gptq_config=mct.gptq.get_pytorch_gptq_config(
+                     n_epochs=1000, use_hessian_based_weights=False, use_hessian_sample_attention=False
+                 ),
+                 core_config=config,
+                 target_platform_capabilities=tpc,
+             )[0]
+             if gptq
+             else mct.ptq.pytorch_post_training_quantization(  # Perform post training quantization
+                 in_module=self.model,
+                 representative_data_gen=representative_dataset_gen,
+                 target_resource_utilization=resource_utilization,
+                 core_config=config,
+                 target_platform_capabilities=tpc,
+             )[0]
+         )
+ 
+         class NMSWrapper(torch.nn.Module):
+             def __init__(
+                 self,
+                 model: torch.nn.Module,
+                 score_threshold: float = 0.001,
+                 iou_threshold: float = 0.7,
+                 max_detections: int = 300,
+             ):
+                 """
+                 Wrap a PyTorch module with the multiclass_nms layer from sony_custom_layers.
+ 
+                 Args:
+                     model (nn.Module): Model instance.
+                     score_threshold (float): Score threshold for non-maximum suppression.
+                     iou_threshold (float): Intersection over union threshold for non-maximum suppression.
+                     max_detections (int): The maximum number of detections to return.
+                 """
+                 super().__init__()
+                 self.model = model
+                 self.score_threshold = score_threshold
+                 self.iou_threshold = iou_threshold
+                 self.max_detections = max_detections
+ 
+             def forward(self, images):
+                 # Model inference
+                 outputs = self.model(images)
+ 
+                 boxes = outputs[0]
+                 scores = outputs[1]
+                 nms = multiclass_nms(
+                     boxes=boxes,
+                     scores=scores,
+                     score_threshold=self.score_threshold,
+                     iou_threshold=self.iou_threshold,
+                     max_detections=self.max_detections,
+                 )
+                 return nms
+ 
+         quant_model = NMSWrapper(
+             model=quant_model,
+             score_threshold=self.args.conf or 0.001,
+             iou_threshold=self.args.iou,
+             max_detections=self.args.max_det,
+         ).to(self.device)
+ 
+         f = Path(str(self.file).replace(self.file.suffix, "_imx_model"))
+         f.mkdir(exist_ok=True)
+         onnx_model = f / Path(str(self.file.name).replace(self.file.suffix, "_imx.onnx"))  # quantized ONNX path
+         mct.exporter.pytorch_export_model(
+             model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
+         )
+ 
+         model_onnx = onnx.load(onnx_model)  # load onnx model
+         for k, v in self.metadata.items():
+             meta = model_onnx.metadata_props.add()
+             meta.key, meta.value = k, str(v)
+ 
+         onnx.save(model_onnx, onnx_model)
+ 
+         subprocess.run(
+             ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
+             check=True,
+         )
+ 
+         # Needed for imx models
+         with open(f / "labels.txt", "w", encoding="utf-8") as file:
+             file.writelines([f"{name}\n" for _, name in self.model.names.items()])
+ 
+         return f, None
+ 
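+     # Usage sketch, assuming the standard Ultralytics entry point (Linux-only export;
+     # dataset name illustrative):
+     #     from ultralytics import YOLO
+     #     YOLO("yolo11n.pt").export(format="imx", data="coco8.yaml")
+ 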
+     def _add_tflite_metadata(self, file):
+         """Add metadata to *.tflite models per https://ai.google.dev/edge/litert/models/metadata."""
+         import zipfile
+ 
+         with zipfile.ZipFile(file, "a", zipfile.ZIP_DEFLATED) as zf:
+             zf.writestr("metadata.json", json.dumps(self.metadata, indent=2))
+ 
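+     # Reading the metadata back, a minimal sketch (the JSON is stored as a zip entry;
+     # path illustrative):
+     #     import json, zipfile
+     #     with zipfile.ZipFile("yolo11n_saved_model/yolo11n_float32.tflite") as zf:
+     #         meta = json.loads(zf.read("metadata.json"))
+ 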
+     def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
+         """YOLO CoreML pipeline."""
+         import coremltools as ct  # noqa
+ 
+         LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
+         _, _, h, w = list(self.im.shape)  # BCHW
+ 
+         # Output shapes
+         spec = model.get_spec()
+         out0, out1 = iter(spec.description.output)
+         if MACOS:
+             from PIL import Image
+ 
+             img = Image.new("RGB", (w, h))  # w=192, h=320
+             out = model.predict({"image": img})
+             out0_shape = out[out0.name].shape  # (3780, 80)
+             out1_shape = out[out1.name].shape  # (3780, 4)
+         else:  # Linux and Windows cannot run model.predict(); get sizes from the PyTorch model output y
+             out0_shape = self.output_shape[2], self.output_shape[1] - 4  # (3780, 80)
+             out1_shape = self.output_shape[2], 4  # (3780, 4)
+ 
+         # Checks
+         names = self.metadata["names"]
+         nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
+         _, nc = out0_shape  # number of anchors, number of classes
+         assert len(names) == nc, f"{len(names)} names found for nc={nc}"  # check
+ 
+         # Define output shapes (missing)
+         out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
+         out1.type.multiArrayType.shape[:] = out1_shape  # (3780, 4)
+ 
+         # Model from spec
+         model = ct.models.MLModel(spec, weights_dir=weights_dir)
+ 
+         # 3. Create NMS protobuf
+         nms_spec = ct.proto.Model_pb2.Model()
+         nms_spec.specificationVersion = spec.specificationVersion
+         for i in range(2):
+             decoder_output = model._spec.description.output[i].SerializeToString()
+             nms_spec.description.input.add()
+             nms_spec.description.input[i].ParseFromString(decoder_output)
+             nms_spec.description.output.add()
+             nms_spec.description.output[i].ParseFromString(decoder_output)
+ 
+         nms_spec.description.output[0].name = "confidence"
+         nms_spec.description.output[1].name = "coordinates"
+ 
+         output_sizes = [nc, 4]
+         for i in range(2):
+             ma_type = nms_spec.description.output[i].type.multiArrayType
+             ma_type.shapeRange.sizeRanges.add()
+             ma_type.shapeRange.sizeRanges[0].lowerBound = 0
+             ma_type.shapeRange.sizeRanges[0].upperBound = -1
+             ma_type.shapeRange.sizeRanges.add()
+             ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
+             ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
+             del ma_type.shape[:]
+ 
+         nms = nms_spec.nonMaximumSuppression
+         nms.confidenceInputFeatureName = out0.name  # 1x507x80
+         nms.coordinatesInputFeatureName = out1.name  # 1x507x4
+         nms.confidenceOutputFeatureName = "confidence"
+         nms.coordinatesOutputFeatureName = "coordinates"
+         nms.iouThresholdInputFeatureName = "iouThreshold"
+         nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
+         nms.iouThreshold = self.args.iou
+         nms.confidenceThreshold = self.args.conf
+         nms.pickTop.perClass = True
+         nms.stringClassLabels.vector.extend(names.values())
+         nms_model = ct.models.MLModel(nms_spec)
+ 
+         # 4. Pipeline models together
+         pipeline = ct.models.pipeline.Pipeline(
+             input_features=[
+                 ("image", ct.models.datatypes.Array(3, ny, nx)),
+                 ("iouThreshold", ct.models.datatypes.Double()),
+                 ("confidenceThreshold", ct.models.datatypes.Double()),
+             ],
+             output_features=["confidence", "coordinates"],
+         )
+         pipeline.add_model(model)
+         pipeline.add_model(nms_model)
+ 
+         # Correct datatypes
+         pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
+         pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
+         pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
+ 
+         # Update metadata
+         pipeline.spec.specificationVersion = spec.specificationVersion
+         pipeline.spec.description.metadata.userDefined.update(
+             {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
+         )
+ 
+         # Save the model
+         model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
+         model.input_description["image"] = "Input image"
+         model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})"
+         model.input_description["confidenceThreshold"] = (
+             f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
+         )
+         model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
+         model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
+         LOGGER.info(f"{prefix} pipeline success")
+         return model
+ 
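+     # Consuming the pipeline on macOS, a minimal sketch with the standard coremltools
+     # predict API (`img` is a PIL image at the model's input size):
+     #     out = model.predict({"image": img, "iouThreshold": 0.7, "confidenceThreshold": 0.25})
+     #     out["confidence"], out["coordinates"]  # per-box class scores and normalized xywh
+ 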
+     def add_callback(self, event: str, callback):
+         """Appends the given callback."""
+         self.callbacks[event].append(callback)
+ 
+     def run_callbacks(self, event: str):
+         """Execute all callbacks for a given event."""
+         for callback in self.callbacks.get(event, []):
+             callback(self)
+ 
+ 
+ class IOSDetectModel(torch.nn.Module):
+     """Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
+ 
+     def __init__(self, model, im):
+         """Initialize the IOSDetectModel class with a YOLO model and example image."""
+         super().__init__()
+         _, _, h, w = im.shape  # batch, channel, height, width
+         self.model = model
+         self.nc = len(model.names)  # number of classes
+         if w == h:
+             self.normalize = 1.0 / w  # scalar
+         else:
+             self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])  # broadcast (slower, smaller)
+ 
+     def forward(self, x):
+         """Normalize predictions of object detection model with input size-dependent factors."""
+         xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
+         return cls, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)
+ 
+ 
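+ # Normalization example for IOSDetectModel: with a square 640x640 input the scalar branch
+ # gives normalize = 1/640, so a pixel-space box (320, 320, 64, 64) maps to (0.5, 0.5, 0.1, 0.1).
+ 
+ 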
+ class NMSModel(torch.nn.Module):
+     """Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
+ 
+     def __init__(self, model, args):
+         """
+         Initialize the NMSModel.
+ 
+         Args:
+             model (torch.nn.Module): The model to wrap with NMS postprocessing.
+             args (Namespace): The export arguments.
+         """
+         super().__init__()
+         self.model = model
+         self.args = args
+         self.obb = model.task == "obb"
+         self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
+ 
+     def forward(self, x):
+         """
+         Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
+ 
+         Args:
+             x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
+ 
+         Returns:
+             (torch.Tensor): Detections of shape (N, max_det, 4 + 2 + extra_shape), where N is the batch size.
+         """
+         from functools import partial
+ 
+         from torchvision.ops import nms
+ 
+         preds = self.model(x)
+         pred = preds[0] if isinstance(preds, tuple) else preds
+         kwargs = dict(device=pred.device, dtype=pred.dtype)
+         bs = pred.shape[0]
+         pred = pred.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
+         extra_shape = pred.shape[-1] - (4 + len(self.model.names))  # extras from Segment, OBB, Pose
+         if self.args.dynamic and self.args.batch > 1:  # batch size needs to always be same due to loop unroll
+             pad = torch.zeros(torch.max(torch.tensor(self.args.batch - bs), torch.tensor(0)), *pred.shape[1:], **kwargs)
+             pred = torch.cat((pred, pad))
+         boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)
+         scores, classes = scores.max(dim=-1)
+         self.args.max_det = min(pred.shape[1], self.args.max_det)  # in case num_anchors < max_det
+         # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape)
+         out = torch.zeros(bs, self.args.max_det, boxes.shape[-1] + 2 + extra_shape, **kwargs)
+         for i in range(bs):
+             box, cls, score, extra = boxes[i], classes[i], scores[i], extras[i]
+             mask = score > self.args.conf
+             if self.is_tf:
+                 # TFLite GatherND error if mask is empty
+                 score *= mask
+                 # Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`
+                 mask = score.topk(min(self.args.max_det * 5, score.shape[0])).indices
+             box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
+             nmsbox = box.clone()
+             # `8` is the minimum value experimented to get correct NMS results for obb
+             multiplier = 8 if self.obb else 1
+             # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
+             if self.args.format == "tflite":  # TFLite is already normalized
+                 nmsbox *= multiplier
+             else:
+                 nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], **kwargs).max()
+             if not self.args.agnostic_nms:  # class-specific NMS
+                 end = 2 if self.obb else 4
+                 # Fully explicit expansion otherwise reshape error;
+                 # large max_wh causes issues when quantizing
+                 cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
+                 offbox = nmsbox[:, :end] + cls_offset * multiplier
+                 nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
+             nms_fn = (
+                 partial(
+                     nms_rotated,
+                     use_triu=not (
+                         self.is_tf
+                         or (self.args.opset or 14) < 14
+                         or (self.args.format == "openvino" and self.args.int8)  # OpenVINO int8 error with triu
+                     ),
+                 )
+                 if self.obb
+                 else nms
+             )
+             keep = nms_fn(
+                 torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,
+                 score,
+                 self.args.iou,
+             )[: self.args.max_det]
+             dets = torch.cat(
+                 [box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1).to(out.dtype), extra[keep]], dim=-1
+             )
+             # Zero-pad to max_det size to avoid reshape error
+             pad = (0, 0, 0, self.args.max_det - dets.shape[0])
+             out[i] = torch.nn.functional.pad(dets, pad)
+         return (out[:bs], preds[1]) if self.model.task == "segment" else out[:bs]
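+ 
+ 
+ # Usage sketch: wrapping a model with NMSModel before export-time tracing, assuming `args`
+ # carries the export settings read above (conf, iou, max_det, format, agnostic_nms, ...):
+ #     wrapped = NMSModel(model, args)
+ #     out = wrapped(torch.zeros(1, 3, 640, 640))  # (1, max_det, 6) for a plain Detect head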