dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,1519 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
"""
|
3
|
+
Export a YOLO PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit.
|
4
|
+
|
5
|
+
Format | `format=argument` | Model
|
6
|
+
--- | --- | ---
|
7
|
+
PyTorch | - | yolo11n.pt
|
8
|
+
TorchScript | `torchscript` | yolo11n.torchscript
|
9
|
+
ONNX | `onnx` | yolo11n.onnx
|
10
|
+
OpenVINO | `openvino` | yolo11n_openvino_model/
|
11
|
+
TensorRT | `engine` | yolo11n.engine
|
12
|
+
CoreML | `coreml` | yolo11n.mlpackage
|
13
|
+
TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/
|
14
|
+
TensorFlow GraphDef | `pb` | yolo11n.pb
|
15
|
+
TensorFlow Lite | `tflite` | yolo11n.tflite
|
16
|
+
TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
|
17
|
+
TensorFlow.js | `tfjs` | yolo11n_web_model/
|
18
|
+
PaddlePaddle | `paddle` | yolo11n_paddle_model/
|
19
|
+
MNN | `mnn` | yolo11n.mnn
|
20
|
+
NCNN | `ncnn` | yolo11n_ncnn_model/
|
21
|
+
IMX | `imx` | yolo11n_imx_model/
|
22
|
+
RKNN | `rknn` | yolo11n_rknn_model/
|
23
|
+
|
24
|
+
Requirements:
|
25
|
+
$ pip install "ultralytics[export]"
|
26
|
+
|
27
|
+
Python:
|
28
|
+
from ultralytics import YOLO
|
29
|
+
model = YOLO('yolo11n.pt')
|
30
|
+
results = model.export(format='onnx')
|
31
|
+
|
32
|
+
CLI:
|
33
|
+
$ yolo mode=export model=yolo11n.pt format=onnx
|
34
|
+
|
35
|
+
Inference:
|
36
|
+
$ yolo predict model=yolo11n.pt # PyTorch
|
37
|
+
yolo11n.torchscript # TorchScript
|
38
|
+
yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
39
|
+
yolo11n_openvino_model # OpenVINO
|
40
|
+
yolo11n.engine # TensorRT
|
41
|
+
yolo11n.mlpackage # CoreML (macOS-only)
|
42
|
+
yolo11n_saved_model # TensorFlow SavedModel
|
43
|
+
yolo11n.pb # TensorFlow GraphDef
|
44
|
+
yolo11n.tflite # TensorFlow Lite
|
45
|
+
yolo11n_edgetpu.tflite # TensorFlow Edge TPU
|
46
|
+
yolo11n_paddle_model # PaddlePaddle
|
47
|
+
yolo11n.mnn # MNN
|
48
|
+
yolo11n_ncnn_model # NCNN
|
49
|
+
yolo11n_imx_model # IMX
|
50
|
+
|
51
|
+
TensorFlow.js:
|
52
|
+
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
53
|
+
$ npm install
|
54
|
+
$ ln -s ../../yolo11n_web_model public/yolo11n_web_model
|
55
|
+
$ npm start
|
56
|
+
"""
|
57
|
+
|
58
|
+
import json
|
59
|
+
import os
|
60
|
+
import re
|
61
|
+
import shutil
|
62
|
+
import subprocess
|
63
|
+
import time
|
64
|
+
import warnings
|
65
|
+
from contextlib import contextmanager
|
66
|
+
from copy import deepcopy
|
67
|
+
from datetime import datetime
|
68
|
+
from pathlib import Path
|
69
|
+
|
70
|
+
import numpy as np
|
71
|
+
import torch
|
72
|
+
|
73
|
+
from ultralytics import __version__
|
74
|
+
from ultralytics.cfg import TASK2DATA, get_cfg
|
75
|
+
from ultralytics.data import build_dataloader
|
76
|
+
from ultralytics.data.dataset import YOLODataset
|
77
|
+
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
78
|
+
from ultralytics.nn.autobackend import check_class_names, default_class_names
|
79
|
+
from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
|
80
|
+
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, WorldModel
|
81
|
+
from ultralytics.utils import (
|
82
|
+
ARM64,
|
83
|
+
DEFAULT_CFG,
|
84
|
+
IS_COLAB,
|
85
|
+
IS_JETSON,
|
86
|
+
LINUX,
|
87
|
+
LOGGER,
|
88
|
+
MACOS,
|
89
|
+
MACOS_VERSION,
|
90
|
+
RKNN_CHIPS,
|
91
|
+
ROOT,
|
92
|
+
SETTINGS,
|
93
|
+
WINDOWS,
|
94
|
+
YAML,
|
95
|
+
callbacks,
|
96
|
+
colorstr,
|
97
|
+
get_default_args,
|
98
|
+
)
|
99
|
+
from ultralytics.utils.checks import (
|
100
|
+
check_imgsz,
|
101
|
+
check_is_path_safe,
|
102
|
+
check_requirements,
|
103
|
+
check_version,
|
104
|
+
is_sudo_available,
|
105
|
+
)
|
106
|
+
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
|
107
|
+
from ultralytics.utils.export import export_engine, export_onnx
|
108
|
+
from ultralytics.utils.files import file_size, spaces_in_path
|
109
|
+
from ultralytics.utils.ops import Profile, nms_rotated
|
110
|
+
from ultralytics.utils.torch_utils import TORCH_1_13, get_cpu_info, get_latest_opset, select_device
|
111
|
+
|
112
|
+
|
113
|
+
def export_formats():
    """
    Return the table of Ultralytics YOLO export formats as a dictionary of columns.

    Returns:
        (dict): Maps each column name ("Format", "Argument", "Suffix", "CPU", "GPU", "Arguments") to a tuple
            holding that column's value for every format, in the row order defined below.
    """
    columns = ["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"]
    # One row per format, laid out in the same order as `columns`
    rows = [
        ("PyTorch", "-", ".pt", True, True, []),
        ("TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "half", "nms"]),
        ("ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]),
        ("OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8", "nms", "fraction"]),
        ("TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify", "nms", "fraction"]),
        ("CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]),
        ("TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]),
        ("TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]),
        ("TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms", "fraction"]),
        ("TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []),
        ("TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]),
        ("PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]),
        ("MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]),
        ("NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]),
        ("IMX", "imx", "_imx_model", True, True, ["int8", "fraction"]),
        ("RKNN", "rknn", "_rknn_model", False, False, ["batch", "name"]),
    ]
    # Transpose the rows so every column name maps to a tuple of per-format values
    return {column: values for column, values in zip(columns, zip(*rows))}
|
148
|
+
|
149
|
+
|
150
|
+
def validate_args(format, passed_args, valid_args):
    """
    Check that every export argument changed from its default is accepted by the chosen format.

    Args:
        format (str): The export format.
        passed_args (Namespace): The arguments used during export.
        valid_args (list): Names of the arguments the format accepts.

    Raises:
        AssertionError: If `valid_args` is None, or if a checked argument differs from its default value while
            not being listed in `valid_args`.
    """
    assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed."

    # Reference configuration: DEFAULT_CFG with the exporter's own defaults applied
    exporter_defaults = {"batch": 1, "data": None, "device": None}
    baseline = get_cfg(DEFAULT_CFG, exporter_defaults)

    # Only these arguments are checked against the per-format list
    for name in ("half", "int8", "dynamic", "keras", "nms", "batch", "fraction"):
        if getattr(passed_args, name, None) != getattr(baseline, name, None):
            assert name in valid_args, f"ERROR ❌️ argument '{name}' is not supported for format='{format}'"
|
171
|
+
|
172
|
+
|
173
|
+
def gd_outputs(gd):
    """
    Return TensorFlow GraphDef model output node names.

    A node is treated as an output when its name does not appear in any node's input list. Names starting with
    "NoOp" are dropped, and each remaining name is returned with a ":0" suffix, sorted.

    Args:
        gd: Object exposing a `node` iterable whose items have `name` and `input` attributes.

    Returns:
        (list[str]): Sorted output names in the form "<name>:0".
    """
    node_names, consumed = set(), set()
    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
        node_names.add(node.name)
        consumed.update(node.input)
    terminal = node_names - consumed  # names that are never used as an input
    return sorted(f"{name}:0" for name in terminal if not name.startswith("NoOp"))
|
180
|
+
|
181
|
+
|
182
|
+
def try_export(inner_func):
    """
    YOLO export decorator, i.e. @try_export.

    The returned wrapper times the call to `inner_func`, logs success (with elapsed time, output path and file size)
    or failure (with elapsed time and the error), and re-raises any exception to the caller.

    Args:
        inner_func (Callable): Export function returning `(f, model)`. Its `prefix` entry is looked up in the result
            of `get_default_args(inner_func)` and used to label log messages.

    Returns:
        (Callable): Wrapper accepting the same arguments as `inner_func` and returning its `(f, model)` result.
    """
    from functools import wraps  # local import so this block does not depend on a file-level import

    inner_args = get_default_args(inner_func)

    @wraps(inner_func)  # keep the export function's __name__ and __doc__ on the wrapper
    def outer_func(*args, **kwargs):
        """Export a model."""
        prefix = inner_args["prefix"]
        dt = 0.0  # rebound to the Profile instance once the timing context is entered
        try:
            with Profile() as dt:
                f, model = inner_func(*args, **kwargs)
            LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
            return f, model
        except Exception as e:
            # `dt` is still the float placeholder if the failure happened before timing started;
            # fall back to 0.0 so the original error is not masked by an AttributeError here
            elapsed = getattr(dt, "t", 0.0)
            LOGGER.error(f"{prefix} export failure {elapsed:.1f}s: {e}")
            raise

    return outer_func
|
200
|
+
|
201
|
+
|
202
|
+
@contextmanager
def arange_patch(args):
    """
    Workaround for ONNX torch.arange incompatibility with FP16.

    https://github.com/pytorch/pytorch/issues/148041.

    When `args.dynamic`, `args.half` and `args.format == "onnx"` all hold, `torch.arange` is temporarily replaced by
    a wrapper that calls the original without `dtype` and casts the result afterwards. The original function is
    restored when the context exits, including when the body of the `with` block raises. Otherwise the context
    manager does nothing.

    Args:
        args: Export configuration exposing `dynamic`, `half` and `format` attributes.
    """
    if not (args.dynamic and args.half and args.format == "onnx"):
        yield
        return

    func = torch.arange

    def arange(*args, dtype=None, **kwargs):
        """Return a 1-D tensor of size with values from the interval and common difference."""
        return func(*args, **kwargs).to(dtype)  # cast to dtype instead of passing dtype

    torch.arange = arange  # patch
    try:
        yield
    finally:
        torch.arange = func  # unpatch, even if the wrapped code raised
|
221
|
+
|
222
|
+
|
223
|
+
class Exporter:
|
224
|
+
"""
|
225
|
+
A class for exporting a model.
|
226
|
+
|
227
|
+
Attributes:
|
228
|
+
args (SimpleNamespace): Configuration for the exporter.
|
229
|
+
callbacks (list, optional): List of callback functions.
|
230
|
+
"""
|
231
|
+
|
232
|
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Set up the exporter's configuration and callbacks.

    Args:
        cfg (str, optional): Base configuration passed to `get_cfg`; defaults to `DEFAULT_CFG`.
        overrides (dict, optional): Configuration overrides passed to `get_cfg` together with `cfg`.
        _callbacks (dict, optional): Callbacks to use; when empty or None the default callbacks are used.
    """
    self.args = get_cfg(cfg, overrides)
    if _callbacks:
        self.callbacks = _callbacks
    else:
        self.callbacks = callbacks.get_default_callbacks()
    callbacks.add_integration_callbacks(self)
|
244
|
+
|
245
|
+
def __call__(self, model=None) -> str:
|
246
|
+
"""Return list of exported files/dirs after running callbacks."""
|
247
|
+
self.run_callbacks("on_export_start")
|
248
|
+
t = time.time()
|
249
|
+
fmt = self.args.format.lower() # to lowercase
|
250
|
+
if fmt in {"tensorrt", "trt"}: # 'engine' aliases
|
251
|
+
fmt = "engine"
|
252
|
+
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases
|
253
|
+
fmt = "coreml"
|
254
|
+
fmts_dict = export_formats()
|
255
|
+
fmts = tuple(fmts_dict["Argument"][1:]) # available export formats
|
256
|
+
if fmt not in fmts:
|
257
|
+
import difflib
|
258
|
+
|
259
|
+
# Get the closest match if format is invalid
|
260
|
+
matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6) # 60% similarity required to match
|
261
|
+
if not matches:
|
262
|
+
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
|
263
|
+
LOGGER.warning(f"Invalid export format='{fmt}', updating to format='{matches[0]}'")
|
264
|
+
fmt = matches[0]
|
265
|
+
flags = [x == fmt for x in fmts]
|
266
|
+
if sum(flags) != 1:
|
267
|
+
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
|
268
|
+
(jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn, imx, rknn) = (
|
269
|
+
flags # export booleans
|
270
|
+
)
|
271
|
+
|
272
|
+
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
|
273
|
+
|
274
|
+
# Device
|
275
|
+
dla = None
|
276
|
+
if fmt == "engine" and self.args.device is None:
|
277
|
+
LOGGER.warning("TensorRT requires GPU export, automatically assigning device=0")
|
278
|
+
self.args.device = "0"
|
279
|
+
if fmt == "engine" and "dla" in str(self.args.device): # convert int/list to str first
|
280
|
+
dla = self.args.device.split(":")[-1]
|
281
|
+
self.args.device = "0" # update device to "0"
|
282
|
+
assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
|
283
|
+
if imx and self.args.device is None and torch.cuda.is_available():
|
284
|
+
LOGGER.warning("Exporting on CPU while CUDA is available, setting device=0 for faster export on GPU.")
|
285
|
+
self.args.device = "0" # update device to "0"
|
286
|
+
self.device = select_device("cpu" if self.args.device is None else self.args.device)
|
287
|
+
|
288
|
+
# Argument compatibility checks
|
289
|
+
fmt_keys = fmts_dict["Arguments"][flags.index(True) + 1]
|
290
|
+
validate_args(fmt, self.args, fmt_keys)
|
291
|
+
if imx:
|
292
|
+
if not self.args.int8:
|
293
|
+
LOGGER.warning("IMX export requires int8=True, setting int8=True.")
|
294
|
+
self.args.int8 = True
|
295
|
+
if model.task != "detect":
|
296
|
+
raise ValueError("IMX export only supported for detection models.")
|
297
|
+
if not hasattr(model, "names"):
|
298
|
+
model.names = default_class_names()
|
299
|
+
model.names = check_class_names(model.names)
|
300
|
+
if self.args.half and self.args.int8:
|
301
|
+
LOGGER.warning("half=True and int8=True are mutually exclusive, setting half=False.")
|
302
|
+
self.args.half = False
|
303
|
+
if self.args.half and onnx and self.device.type == "cpu":
|
304
|
+
LOGGER.warning("half=True only compatible with GPU export, i.e. use device=0")
|
305
|
+
self.args.half = False
|
306
|
+
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
307
|
+
if self.args.int8 and engine:
|
308
|
+
self.args.dynamic = True # enforce dynamic to export TensorRT INT8
|
309
|
+
if self.args.optimize:
|
310
|
+
assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
|
311
|
+
assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
|
312
|
+
if rknn:
|
313
|
+
if not self.args.name:
|
314
|
+
LOGGER.warning(
|
315
|
+
"Rockchip RKNN export requires a missing 'name' arg for processor type. "
|
316
|
+
"Using default name='rk3588'."
|
317
|
+
)
|
318
|
+
self.args.name = "rk3588"
|
319
|
+
self.args.name = self.args.name.lower()
|
320
|
+
assert self.args.name in RKNN_CHIPS, (
|
321
|
+
f"Invalid processor name '{self.args.name}' for Rockchip RKNN export. Valid names are {RKNN_CHIPS}."
|
322
|
+
)
|
323
|
+
if self.args.int8 and tflite:
|
324
|
+
assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
|
325
|
+
if self.args.nms:
|
326
|
+
assert not isinstance(model, ClassificationModel), "'nms=True' is not valid for classification models."
|
327
|
+
assert not (tflite and ARM64 and LINUX), "TFLite export with NMS unsupported on ARM64 Linux"
|
328
|
+
if getattr(model, "end2end", False):
|
329
|
+
LOGGER.warning("'nms=True' is not available for end2end models. Forcing 'nms=False'.")
|
330
|
+
self.args.nms = False
|
331
|
+
self.args.conf = self.args.conf or 0.25 # set conf default value for nms export
|
332
|
+
if edgetpu:
|
333
|
+
if not LINUX or ARM64:
|
334
|
+
raise SystemError(
|
335
|
+
"Edge TPU export only supported on non-aarch64 Linux. See https://coral.ai/docs/edgetpu/compiler"
|
336
|
+
)
|
337
|
+
elif self.args.batch != 1: # see github.com/ultralytics/ultralytics/pull/13420
|
338
|
+
LOGGER.warning("Edge TPU export requires batch size 1, setting batch=1.")
|
339
|
+
self.args.batch = 1
|
340
|
+
if isinstance(model, WorldModel):
|
341
|
+
LOGGER.warning(
|
342
|
+
"YOLOWorld (original version) export is not supported to any format. "
|
343
|
+
"YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to "
|
344
|
+
"(torchscript, onnx, openvino, engine, coreml) formats. "
|
345
|
+
"See https://docs.ultralytics.com/models/yolo-world for details."
|
346
|
+
)
|
347
|
+
model.clip_model = None # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
|
348
|
+
if self.args.int8 and not self.args.data:
|
349
|
+
self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")] # assign default data
|
350
|
+
LOGGER.warning(
|
351
|
+
f"INT8 export requires a missing 'data' arg for calibration. Using default 'data={self.args.data}'."
|
352
|
+
)
|
353
|
+
if tfjs and (ARM64 and LINUX):
|
354
|
+
raise SystemError("TF.js exports are not currently supported on ARM64 Linux")
|
355
|
+
# Recommend OpenVINO if export and Intel CPU
|
356
|
+
if SETTINGS.get("openvino_msg"):
|
357
|
+
if "intel" in get_cpu_info().lower():
|
358
|
+
LOGGER.info(
|
359
|
+
"💡 ProTip: Export to OpenVINO format for best performance on Intel CPUs."
|
360
|
+
" Learn more at https://docs.ultralytics.com/integrations/openvino/"
|
361
|
+
)
|
362
|
+
SETTINGS["openvino_msg"] = False
|
363
|
+
|
364
|
+
# Input
|
365
|
+
im = torch.zeros(self.args.batch, model.yaml.get("channels", 3), *self.imgsz).to(self.device)
|
366
|
+
file = Path(
|
367
|
+
getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
|
368
|
+
)
|
369
|
+
if file.suffix in {".yaml", ".yml"}:
|
370
|
+
file = Path(file.name)
|
371
|
+
|
372
|
+
# Update model
|
373
|
+
model = deepcopy(model).to(self.device)
|
374
|
+
for p in model.parameters():
|
375
|
+
p.requires_grad = False
|
376
|
+
model.eval()
|
377
|
+
model.float()
|
378
|
+
model = model.fuse()
|
379
|
+
|
380
|
+
if imx:
|
381
|
+
from ultralytics.utils.torch_utils import FXModel
|
382
|
+
|
383
|
+
model = FXModel(model)
|
384
|
+
for m in model.modules():
|
385
|
+
if isinstance(m, Classify):
|
386
|
+
m.export = True
|
387
|
+
if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB
|
388
|
+
m.dynamic = self.args.dynamic
|
389
|
+
m.export = True
|
390
|
+
m.format = self.args.format
|
391
|
+
m.max_det = self.args.max_det
|
392
|
+
m.xyxy = self.args.nms and not coreml
|
393
|
+
elif isinstance(m, C2f) and not is_tf_format:
|
394
|
+
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
|
395
|
+
m.forward = m.forward_split
|
396
|
+
if isinstance(m, Detect) and imx:
|
397
|
+
from ultralytics.utils.tal import make_anchors
|
398
|
+
|
399
|
+
m.anchors, m.strides = (
|
400
|
+
x.transpose(0, 1)
|
401
|
+
for x in make_anchors(
|
402
|
+
torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
|
403
|
+
)
|
404
|
+
)
|
405
|
+
|
406
|
+
y = None
|
407
|
+
for _ in range(2): # dry runs
|
408
|
+
y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
|
409
|
+
if self.args.half and onnx and self.device.type != "cpu":
|
410
|
+
im, model = im.half(), model.half() # to FP16
|
411
|
+
|
412
|
+
# Filter warnings
|
413
|
+
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning
|
414
|
+
warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning
|
415
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
|
416
|
+
|
417
|
+
# Assign
|
418
|
+
self.im = im
|
419
|
+
self.model = model
|
420
|
+
self.file = file
|
421
|
+
self.output_shape = (
|
422
|
+
tuple(y.shape)
|
423
|
+
if isinstance(y, torch.Tensor)
|
424
|
+
else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
|
425
|
+
)
|
426
|
+
self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
|
427
|
+
data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
|
428
|
+
description = f"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}"
|
429
|
+
self.metadata = {
|
430
|
+
"description": description,
|
431
|
+
"author": "Ultralytics",
|
432
|
+
"date": datetime.now().isoformat(),
|
433
|
+
"version": __version__,
|
434
|
+
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
435
|
+
"docs": "https://docs.ultralytics.com",
|
436
|
+
"stride": int(max(model.stride)),
|
437
|
+
"task": model.task,
|
438
|
+
"batch": self.args.batch,
|
439
|
+
"imgsz": self.imgsz,
|
440
|
+
"names": model.names,
|
441
|
+
"args": {k: v for k, v in self.args if k in fmt_keys},
|
442
|
+
"channels": model.yaml.get("channels", 3),
|
443
|
+
} # model metadata
|
444
|
+
if dla is not None:
|
445
|
+
self.metadata["dla"] = dla # make sure `AutoBackend` uses correct dla device if it has one
|
446
|
+
if model.task == "pose":
|
447
|
+
self.metadata["kpt_shape"] = model.model[-1].kpt_shape
|
448
|
+
|
449
|
+
LOGGER.info(
|
450
|
+
f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
|
451
|
+
f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)"
|
452
|
+
)
|
453
|
+
|
454
|
+
# Exports
|
455
|
+
f = [""] * len(fmts) # exported filenames
|
456
|
+
if jit or ncnn: # TorchScript
|
457
|
+
f[0], _ = self.export_torchscript()
|
458
|
+
if engine: # TensorRT required before ONNX
|
459
|
+
f[1], _ = self.export_engine(dla=dla)
|
460
|
+
if onnx: # ONNX
|
461
|
+
f[2], _ = self.export_onnx()
|
462
|
+
if xml: # OpenVINO
|
463
|
+
f[3], _ = self.export_openvino()
|
464
|
+
if coreml: # CoreML
|
465
|
+
f[4], _ = self.export_coreml()
|
466
|
+
if is_tf_format: # TensorFlow formats
|
467
|
+
self.args.int8 |= edgetpu
|
468
|
+
f[5], keras_model = self.export_saved_model()
|
469
|
+
if pb or tfjs: # pb prerequisite to tfjs
|
470
|
+
f[6], _ = self.export_pb(keras_model=keras_model)
|
471
|
+
if tflite:
|
472
|
+
f[7], _ = self.export_tflite()
|
473
|
+
if edgetpu:
|
474
|
+
f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
|
475
|
+
if tfjs:
|
476
|
+
f[9], _ = self.export_tfjs()
|
477
|
+
if paddle: # PaddlePaddle
|
478
|
+
f[10], _ = self.export_paddle()
|
479
|
+
if mnn: # MNN
|
480
|
+
f[11], _ = self.export_mnn()
|
481
|
+
if ncnn: # NCNN
|
482
|
+
f[12], _ = self.export_ncnn()
|
483
|
+
if imx:
|
484
|
+
f[13], _ = self.export_imx()
|
485
|
+
if rknn:
|
486
|
+
f[14], _ = self.export_rknn()
|
487
|
+
|
488
|
+
# Finish
|
489
|
+
f = [str(x) for x in f if x] # filter out '' and None
|
490
|
+
if any(f):
|
491
|
+
f = str(Path(f[-1]))
|
492
|
+
square = self.imgsz[0] == self.imgsz[1]
|
493
|
+
s = (
|
494
|
+
""
|
495
|
+
if square
|
496
|
+
else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
|
497
|
+
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
|
498
|
+
)
|
499
|
+
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
|
500
|
+
predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
|
501
|
+
q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization
|
502
|
+
LOGGER.info(
|
503
|
+
f"\nExport complete ({time.time() - t:.1f}s)"
|
504
|
+
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
505
|
+
f"\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}"
|
506
|
+
f"\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}"
|
507
|
+
f"\nVisualize: https://netron.app"
|
508
|
+
)
|
509
|
+
|
510
|
+
self.run_callbacks("on_export_end")
|
511
|
+
return f # return list of exported files/dirs
|
512
|
+
|
513
|
+
def get_int8_calibration_dataloader(self, prefix=""):
    """
    Build and return the dataloader used to calibrate INT8 models.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        Dataloader produced by `build_dataloader` over the calibration split (`self.args.split`, or "val" if unset).

    Raises:
        ValueError: If the calibration dataset holds fewer images than `self.args.batch`.
    """
    LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
    if self.model.task == "classify":
        data = check_cls_dataset(self.args.data)
    else:
        data = check_det_dataset(self.args.data)

    # TensorRT INT8 calibration should use 2x batch size
    multiplier = 2 if self.args.format == "engine" else 1
    batch = self.args.batch * multiplier

    split = self.args.split or "val"
    dataset = YOLODataset(
        data[split],
        data=data,
        fraction=self.args.fraction,
        task=self.model.task,
        imgsz=self.imgsz[0],
        augment=False,
        batch_size=batch,
    )

    n = len(dataset)
    if n < self.args.batch:
        raise ValueError(
            f"The calibration dataset ({n} images) must have at least as many images as the batch size "
            f"('batch={self.args.batch}')."
        )
    if n < 300:
        LOGGER.warning(f"{prefix} >300 images recommended for INT8 calibration, found {n} images.")
    return build_dataloader(dataset, batch=batch, workers=0)  # required for batch loading
|
537
|
+
|
538
|
+
@try_export
def export_torchscript(self, prefix=colorstr("TorchScript:")):
    """
    Trace the model with `torch.jit.trace` and save it as a `.torchscript` file.

    The export metadata is embedded as a JSON string under the extra file name "config.txt".

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(path, None)` where `path` is `self.file` with the `.torchscript` suffix.
    """
    LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
    f = self.file.with_suffix(".torchscript")

    # Wrap with NMS when requested so post-processing is part of the traced graph
    traced_source = NMSModel(self.model, self.args) if self.args.nms else self.model
    ts = torch.jit.trace(traced_source, self.im, strict=False)
    extra_files = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()

    if not self.args.optimize:
        ts.save(str(f), _extra_files=extra_files)
        return f, None

    # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
    LOGGER.info(f"{prefix} optimizing for mobile...")
    from torch.utils.mobile_optimizer import optimize_for_mobile

    optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
    return f, None
|
554
|
+
|
555
|
+
@try_export
def export_onnx(self, prefix=colorstr("ONNX:")):
    """
    Export the model to ONNX, optionally slim it with onnxslim, and attach the export metadata.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(path, model_onnx)` where `path` is the string path of the `.onnx` file and `model_onnx` is the
            loaded (and possibly slimmed) ONNX model that was saved there.
    """
    requirements = ["onnx>=1.12.0,<1.18.0"]
    if self.args.simplify:
        runtime_pkg = "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")
        requirements.extend(["onnxslim>=0.1.46", runtime_pkg])
    check_requirements(requirements)
    import onnx  # noqa

    opset_version = self.args.opset or get_latest_opset()
    LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
    f = str(self.file.with_suffix(".onnx"))

    if isinstance(self.model, SegmentationModel):
        output_names = ["output0", "output1"]
    else:
        output_names = ["output0"]

    dynamic = self.args.dynamic
    if dynamic:
        # Replace the truthy flag with the per-tensor dynamic-axes mapping
        dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
        if isinstance(self.model, SegmentationModel):
            dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
            dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
        elif isinstance(self.model, DetectionModel):
            dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
        if self.args.nms:  # only batch size is dynamic with NMS
            dynamic["output0"].pop(2)
    if self.args.nms and self.model.task == "obb":
        self.args.opset = opset_version  # for NMSModel

    with arange_patch(self.args):
        export_source = NMSModel(self.model, self.args) if self.args.nms else self.model
        export_onnx(
            export_source,
            self.im,
            f,
            opset=opset_version,
            input_names=["images"],
            output_names=output_names,
            dynamic=dynamic or None,
        )

    # Checks
    model_onnx = onnx.load(f)  # load onnx model

    # Simplify (best effort: a failure is logged and the unsimplified model is kept)
    if self.args.simplify:
        try:
            import onnxslim

            LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...")
            model_onnx = onnxslim.slim(model_onnx)

        except Exception as e:
            LOGGER.warning(f"{prefix} simplifier failure: {e}")

    # Metadata: every value is stored as its string representation
    for key, value in self.metadata.items():
        entry = model_onnx.metadata_props.add()
        entry.key = key
        entry.value = str(value)

    onnx.save(model_onnx, f)
    return f, model_onnx
|
613
|
+
|
614
|
+
@try_export
def export_openvino(self, prefix=colorstr("OpenVINO:")):
    """
    Convert the model to OpenVINO and save it, quantizing with NNCF first when `self.args.int8` is set.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(export_dir, None)` where `export_dir` is the string path of the `*_openvino_model` directory, or
            of the `*_int8_openvino_model` directory for INT8 exports.
    """
    if MACOS:
        msg = "OpenVINO error in macOS>=15.4 https://github.com/openvinotoolkit/openvino/issues/30023"
        check_version(MACOS_VERSION, "<15.4", name="macOS ", hard=True, msg=msg)
    check_requirements("openvino>=2024.0.0")
    import openvino as ov

    LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
    assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
    source_model = NMSModel(self.model, self.args) if self.args.nms else self.model
    fixed_input = None if self.args.dynamic else [self.im.shape]
    ov_model = ov.convert_model(source_model, input=fixed_input, example_input=self.im)

    def serialize(ov_model, file):
        """Set RT info, serialize, and save metadata YAML."""
        fixed_info = (
            ("model_type", "YOLO"),
            ("reverse_input_channels", True),
            ("pad_value", 114),
            ("scale_values", [255.0]),
            ("iou_threshold", self.args.iou),
        )
        for key, value in fixed_info:
            ov_model.set_rt_info(value, ["model_info", key])
        labels = [v.replace(" ", "_") for v in self.model.names.values()]
        ov_model.set_rt_info(labels, ["model_info", "labels"])
        if self.model.task != "classify":
            ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])

        ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
        YAML.save(Path(file).parent / "metadata.yaml", self.metadata)  # add metadata.yaml

    if not self.args.int8:
        f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
        f_ov = str(Path(f) / self.file.with_suffix(".xml").name)

        serialize(ov_model, f_ov)
        return f, None

    fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
    fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
    # INT8 requires nncf, nncf requires packaging>=23.2 https://github.com/openvinotoolkit/nncf/issues/3463
    check_requirements("packaging>=23.2")  # must be installed first to build nncf wheel
    check_requirements("nncf>=2.14.0")
    import nncf

    def transform_fn(data_item) -> np.ndarray:
        """Quantization transform function."""
        tensor: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
        assert tensor.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
        im = tensor.numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0-255 to 0.0-1.0
        if im.ndim == 3:
            return np.expand_dims(im, 0)
        return im

    # Generate calibration data for integer quantization
    ignored_scope = None
    if isinstance(self.model.model[-1], Detect):
        # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect
        last_module_name = list(self.model.named_modules())[-1][0]
        head = ".".join(last_module_name.split(".")[:2])
        ignored_scope = nncf.IgnoredScope(  # ignore operations
            patterns=[
                f".*{head}/.*/Add",
                f".*{head}/.*/Sub*",
                f".*{head}/.*/Mul*",
                f".*{head}/.*/Div*",
                f".*{head}\\.dfl.*",
            ],
            types=["Sigmoid"],
        )

    calibration_dataset = nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn)
    quantized_ov_model = nncf.quantize(
        model=ov_model,
        calibration_dataset=calibration_dataset,
        preset=nncf.QuantizationPreset.MIXED,
        ignored_scope=ignored_scope,
    )
    serialize(quantized_ov_model, fq_ov)
    return fq, None
|
690
|
+
|
691
|
+
@try_export
def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
    """
    Export the model to PaddlePaddle format through X2Paddle's `pytorch2paddle`.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(export_dir, None)` where `export_dir` is the string path of the `*_paddle_model` directory.
    """
    assert not IS_JETSON, "Jetson Paddle exports not supported yet"
    paddle_requirement = "paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle>=3.0.0"
    check_requirements((paddle_requirement, "x2paddle"))
    import x2paddle  # noqa
    from x2paddle.convert import pytorch2paddle  # noqa

    LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
    f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")

    pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im])  # export
    metadata_file = Path(f) / "metadata.yaml"
    YAML.save(metadata_file, self.metadata)  # add metadata.yaml
    return f, None
|
705
|
+
|
706
|
+
@try_export
def export_mnn(self, prefix=colorstr("MNN:")):
    """
    Export the model to MNN (https://github.com/alibaba/MNN) by converting an intermediate ONNX export.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(path, None)` where `path` is the string path of the `.mnn` file.
    """
    f_onnx, _ = self.export_onnx()  # get onnx model first

    check_requirements("MNN>=2.9.6")
    import MNN  # noqa
    from MNN.tools import mnnconvert

    # Setup and checks
    LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...")
    assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
    f = str(self.file.with_suffix(".mnn"))  # MNN model file

    # Converter arguments; the export metadata is passed as a JSON string via --bizCode
    args = ["", "-f", "ONNX"]
    args += ["--modelFile", f_onnx]
    args += ["--MNNModel", f]
    args += ["--bizCode", json.dumps(self.metadata)]
    if self.args.int8:
        args += ["--weightQuantBits", "8"]
    if self.args.half:
        args += ["--fp16"]
    mnnconvert.convert(args)

    # remove scratch file for model convert optimize
    scratch_file = Path(self.file.parent / ".__convert_external_data.bin")
    if scratch_file.exists():
        scratch_file.unlink()
    return f, None
|
730
|
+
|
731
|
+
@try_export
def export_ncnn(self, prefix=colorstr("NCNN:")):
    """
    Export the model to NCNN by running the PNNX binary (https://github.com/pnnx/pnnx) on the TorchScript file.

    The PNNX executable is looked up in the current working directory and then in `ROOT`; when it is missing a
    release archive is downloaded and the binary is moved into `ROOT`. The method expects the `.torchscript`
    file next to `self.file` to exist already.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(export_dir, None)` where `export_dir` is the string path of the `*_ncnn_model` directory.

    Raises:
        subprocess.CalledProcessError: If the PNNX command exits with a non-zero status.
    """
    check_requirements("ncnn")
    import ncnn  # noqa

    LOGGER.info(f"\n{prefix} starting export with NCNN {ncnn.__version__}...")
    f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
    f_ts = self.file.with_suffix(".torchscript")

    name = Path("pnnx.exe" if WINDOWS else "pnnx")  # PNNX filename
    pnnx = name if name.is_file() else (ROOT / name)
    if not pnnx.is_file():
        LOGGER.warning(
            f"{prefix} PNNX not found. Attempting to download binary file from "
            "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
            f"or in {ROOT}. See PNNX repo for full installation instructions."
        )
        system = "macos" if MACOS else "windows" if WINDOWS else "linux-aarch64" if ARM64 else "linux"
        try:
            release, assets = get_github_assets(repo="pnnx/pnnx")
            asset = [x for x in assets if f"{system}.zip" in x][0]
            assert isinstance(asset, str), "Unable to retrieve PNNX repo assets"  # i.e. pnnx-20240410-macos.zip
            LOGGER.info(f"{prefix} successfully found latest PNNX asset file {asset}")
        except Exception as e:
            # Deliberate best effort: fall back to a pinned release when the asset lookup fails
            release = "20240410"
            asset = f"pnnx-{release}-{system}.zip"
            LOGGER.warning(f"{prefix} PNNX GitHub assets not found: {e}, using default {asset}")
        unzip_dir = safe_download(f"https://github.com/pnnx/pnnx/releases/download/{release}/{asset}", delete=True)
        if check_is_path_safe(Path.cwd(), unzip_dir):  # avoid path traversal security vulnerability
            shutil.move(src=unzip_dir / name, dst=pnnx)  # move binary to ROOT
            pnnx.chmod(0o777)  # set read, write, and execute permissions for everyone
            shutil.rmtree(unzip_dir)  # delete unzip dir

    ncnn_args = [
        f"ncnnparam={f / 'model.ncnn.param'}",
        f"ncnnbin={f / 'model.ncnn.bin'}",
        f"ncnnpy={f / 'model_ncnn.py'}",
    ]

    pnnx_args = [
        f"pnnxparam={f / 'model.pnnx.param'}",
        f"pnnxbin={f / 'model.pnnx.bin'}",
        f"pnnxpy={f / 'model_pnnx.py'}",
        f"pnnxonnx={f / 'model.pnnx.onnx'}",
    ]

    # Take the channel count from the export input tensor rather than assuming 3, so the shape handed to PNNX
    # matches the tensor the model was exported with (identical output for 3-channel models).
    channels = self.im.shape[1]
    cmd = [
        str(pnnx),
        str(f_ts),
        *ncnn_args,
        *pnnx_args,
        f"fp16={int(self.args.half)}",
        f"device={self.device.type}",
        f'inputshape="{[self.args.batch, channels, *self.imgsz]}"',
    ]
    f.mkdir(exist_ok=True)  # make ncnn_model directory
    LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
    subprocess.run(cmd, check=True)

    # Remove debug files
    pnnx_files = [x.split("=")[-1] for x in pnnx_args]
    for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
        Path(f_debug).unlink(missing_ok=True)

    YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
    return str(f), None
|
798
|
+
|
799
|
+
@try_export
def export_coreml(self, prefix=colorstr("CoreML:")):
    """
    Export the model to CoreML as an `.mlpackage`, or as a legacy `.mlmodel` when `format="mlmodel"`.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(path, ct_model)` where `path` is the saved model path and `ct_model` is the converted
            coremltools model.

    Raises:
        AssertionError: On Windows, or when `self.args.batch` is not 1.
    """
    mlmodel = self.args.format.lower() == "mlmodel"  # legacy *.mlmodel export format requested
    check_requirements("coremltools>=8.0")
    import coremltools as ct  # noqa

    LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
    assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux."
    assert self.args.batch == 1, "CoreML batch sizes > 1 are not supported. Please retry at 'batch=1'."
    f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
    if f.is_dir():
        shutil.rmtree(f)

    bias = [0.0, 0.0, 0.0]
    scale = 1 / 255
    classifier_config = None
    if self.model.task == "classify":
        classifier_config = ct.ClassifierConfig(list(self.model.names.values()))
        model = self.model
    elif self.model.task == "detect":
        model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
    else:
        if self.args.nms:
            LOGGER.warning(f"{prefix} 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
            # TODO CoreML Segment and Pose model pipelining
        model = self.model
    ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model

    # Based on apple's documentation it is better to leave out the minimum_deployment target and let that get set
    # Internally based on the model conversion and output type.
    # Setting minimum_depoloyment_target >= iOS16 will require setting compute_precision=ct.precision.FLOAT32.
    # iOS16 adds in better support for FP16, but none of the CoreML NMS specifications handle FP16 as input.
    ct_model = ct.convert(
        ts,
        inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],  # expects ct.TensorType
        classifier_config=classifier_config,
        convert_to="neuralnetwork" if mlmodel else "mlprogram",
    )
    bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
    if bits < 32:
        if "kmeans" in mode:
            check_requirements("scikit-learn")  # scikit-learn package required for k-means quantization
        if mlmodel:
            ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
        elif bits == 8:  # mlprogram already quantized to FP16
            import coremltools.optimize.coreml as cto

            op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
            config = cto.OptimizationConfig(global_config=op_config)
            ct_model = cto.palettize_weights(ct_model, config=config)
    if self.args.nms and self.model.task == "detect":
        if mlmodel:
            weights_dir = None
        else:
            ct_model.save(str(f))  # save otherwise weights_dir does not exist
            weights_dir = str(f / "Data/com.apple.CoreML/weights")
        ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)

    # Pop from a shallow copy so `self.metadata` keeps its description/author/license/version entries; popping
    # from the shared dict would strip them for later readers and make a second call raise KeyError.
    m = dict(self.metadata)  # metadata dict
    ct_model.short_description = m.pop("description")
    ct_model.author = m.pop("author")
    ct_model.license = m.pop("license")
    ct_model.version = m.pop("version")
    ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
    if self.model.task == "classify":
        ct_model.user_defined_metadata.update({"com.apple.coreml.model.preview.type": "imageClassifier"})

    try:
        ct_model.save(str(f))  # save *.mlpackage
    except Exception as e:
        # Deliberate fallback: retry the save with the legacy *.mlmodel suffix
        LOGGER.warning(
            f"{prefix} CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
            f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
        )
        f = f.with_suffix(".mlmodel")
        ct_model.save(str(f))
    return f, ct_model
|
877
|
+
|
878
|
+
@try_export
def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
    """
    Export the model to a TensorRT engine (https://developer.nvidia.com/tensorrt) from an intermediate ONNX file.

    Args:
        dla: Value forwarded unchanged to the engine builder as its `dla` argument.
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(path, None)` where `path` is `self.file` with the `.engine` suffix.

    Raises:
        AssertionError: If the export input tensor is on CPU, or if the ONNX export file does not exist.
    """
    assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
    f_onnx, _ = self.export_onnx()  # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016

    try:
        import tensorrt as trt  # noqa
    except ImportError:
        if LINUX:
            check_requirements("tensorrt>7.0.0,!=10.1.0")
        import tensorrt as trt  # noqa
    trt_version = trt.__version__
    check_version(trt_version, ">=7.0.0", hard=True)
    check_version(trt_version, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")

    # Setup and checks
    LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
    assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
    f = self.file.with_suffix(".engine")  # TensorRT engine file

    # A calibration dataloader is only built for INT8 exports
    calibration_loader = None
    if self.args.int8:
        calibration_loader = self.get_int8_calibration_dataloader(prefix)

    export_engine(
        f_onnx,
        f,
        self.args.workspace,
        self.args.half,
        self.args.int8,
        self.args.dynamic,
        self.im.shape,
        dla=dla,
        dataset=calibration_loader,
        metadata=self.metadata,
        verbose=self.args.verbose,
        prefix=prefix,
    )

    return f, None
|
913
|
+
|
914
|
+
@try_export
def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
    """
    Export the model to a TensorFlow SavedModel directory by converting an ONNX export with onnx2tf.

    Side effects visible in this method: `self.args.simplify` is forced to True before the ONNX export, and any
    existing `*_saved_model` directory is deleted first.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(export_dir, keras_model)` where `export_dir` is the string path of the `*_saved_model`
            directory and `keras_model` is the object returned by `onnx2tf.convert`.
    """
    cuda = torch.cuda.is_available()
    try:
        import tensorflow as tf  # noqa
    except ImportError:
        check_requirements("tensorflow>=2.0.0")
        import tensorflow as tf  # noqa
    check_requirements(
        (
            "tf_keras",  # required by 'onnx2tf' package
            "sng4onnx>=1.0.1",  # required by 'onnx2tf' package
            "onnx_graphsurgeon>=0.3.26",  # required by 'onnx2tf' package
            "ai-edge-litert>=1.2.0",  # required by 'onnx2tf' package
            "onnx>=1.12.0",
            "onnx2tf>=1.26.3",
            "onnxslim>=0.1.46",
            "onnxruntime-gpu" if cuda else "onnxruntime",
            "protobuf>=5",
        ),
        cmds="--extra-index-url https://pypi.ngc.nvidia.com",  # onnx_graphsurgeon only on NVIDIA
    )

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    check_version(
        tf.__version__,
        ">=2.0.0",
        name="tensorflow",
        verbose=True,
        msg="https://github.com/ultralytics/ultralytics/issues/5161",
    )
    import onnx2tf

    f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
    if f.is_dir():
        shutil.rmtree(f)  # delete output folder

    # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
    onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
    if not onnx2tf_file.exists():
        attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)

    # Export to ONNX
    self.args.simplify = True
    f_onnx, _ = self.export_onnx()

    # Export to TF
    np_data = None
    if self.args.int8:
        tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
        if self.args.data:
            f.mkdir()
            images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
            images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
                0, 2, 3, 1
            )
            np.save(str(tmp_file), images.numpy().astype(np.float32))  # BHWC
            np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]

    LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
    keras_model = onnx2tf.convert(
        input_onnx_file_path=f_onnx,
        output_folder_path=str(f),
        not_use_onnxsim=True,
        verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
        output_integer_quantized_tflite=self.args.int8,
        quant_type="per-tensor",  # "per-tensor" (faster) or "per-channel" (slower but more accurate)
        custom_input_op_name_np_data_path=np_data,
        enable_batchmatmul_unfold=True,  # fix lower no. of detected objects on GPU delegate
        output_signaturedefs=True,  # fix error with Attention block group convolution
        optimization_for_gpu_delegate=True,
    )
    YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml

    # Remove/rename TFLite models
    if self.args.int8:
        tmp_file.unlink(missing_ok=True)
        for file in f.rglob("*_dynamic_range_quant.tflite"):
            file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
        for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
            file.unlink()  # delete extra fp16 activation TFLite files

    # Add TFLite metadata. The filter must test and delete the individual `file`, not the output directory `f`:
    # checking `f` never matched a model file, and unlinking `f` would target the directory itself.
    for file in f.rglob("*.tflite"):
        if "quant_with_int16_act.tflite" in str(file):
            file.unlink()
        else:
            self._add_tflite_metadata(file)

    return str(f), keras_model  # or keras_model = tf.saved_model.load(f, tags=None, options=None)
|
1002
|
+
|
1003
|
+
@try_export
def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
    """
    Export a frozen TensorFlow GraphDef `*.pb` file https://github.com/leimao/Frozen-Graph-TensorFlow.

    Args:
        keras_model: Callable model exposing `inputs`; its first input supplies the shape and dtype of the
            traced signature.
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(path, None)` where `path` is `self.file` with the `.pb` suffix.
    """
    import tensorflow as tf  # noqa
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    f = self.file.with_suffix(".pb")

    wrapped = tf.function(lambda x: keras_model(x))  # full model
    first_input = keras_model.inputs[0]
    input_spec = tf.TensorSpec(first_input.shape, first_input.dtype)
    concrete = wrapped.get_concrete_function(input_spec)
    frozen_func = convert_variables_to_constants_v2(concrete)
    frozen_func.graph.as_graph_def()
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
    return f, None
|
1018
|
+
|
1019
|
+
@try_export
def export_tflite(self, prefix=colorstr("TensorFlow Lite:")):
    """
    Return the path of the TFLite file inside the `*_saved_model` directory for the requested precision.

    No conversion runs here; the method only logs and builds the file path from `self.args.int8` and
    `self.args.half`.

    Args:
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(path, None)` where `path` is the string path of the `_int8`, `_float16` or `_float32`
            `.tflite` file.
    """
    # BUG https://github.com/ultralytics/ultralytics/issues/13436
    import tensorflow as tf  # noqa

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))

    # INT8 takes priority over FP16; both variants keep fp32 in/out
    if self.args.int8:
        precision = "int8"
    elif self.args.half:
        precision = "float16"
    else:
        precision = "float32"
    f = saved_model / f"{self.file.stem}_{precision}.tflite"
    return str(f), None
|
1034
|
+
|
1035
|
+
@try_export
def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
    """
    Compile a TFLite model for Edge TPU with `edgetpu_compiler` https://coral.ai/docs/edgetpu/models-intro/.

    When the compiler is not found, an install through apt is attempted first.

    Args:
        tflite_model (str | Path): Path of the TFLite model to compile.
        prefix (str): Tag prepended to log messages.

    Returns:
        (tuple): `(path, None)` where `path` is `tflite_model` with ".tflite" replaced by "_edgetpu.tflite".

    Raises:
        AssertionError: If not running on Linux.
    """
    cmd = "edgetpu_compiler --version"
    help_url = "https://coral.ai/docs/edgetpu/compiler/"
    assert LINUX, f"export only supported on Linux. See {help_url}"

    probe = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
    if probe.returncode != 0:
        LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
        install_cmds = (
            "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
            'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
            "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
            "sudo apt-get update",
            "sudo apt-get install edgetpu-compiler",
        )
        for c in install_cmds:
            # Strip "sudo " when sudo is not available
            subprocess.run(c if is_sudo_available() else c.replace("sudo ", ""), shell=True, check=True)
    version_output = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout
    ver = version_output.decode().split()[-1]

    LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
    f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model

    cmd = " ".join(
        [
            "edgetpu_compiler",
            f'--out_dir "{Path(f).parent}"',
            "--show_operations",
            "--search_delegate",
            "--delegate_search_step 30",
            "--timeout_sec 180",
            f'"{tflite_model}"',
        ]
    )
    LOGGER.info(f"{prefix} running '{cmd}'")
    subprocess.run(cmd, shell=True)
    self._add_tflite_metadata(f)
    return f, None
|
1069
|
+
|
1070
|
+
@try_export
def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
    """
    Convert the frozen TensorFlow graph (`*.pb`) of this model to a TensorFlow.js web model directory.

    The `*.pb` file is read from the path of `self.file` with its suffix replaced by `.pb`; the output directory
    is the path of `self.file` with its suffix replaced by `_web_model`. A `metadata.yaml` is saved inside it.

    Args:
        prefix (str): Prefix used for log messages.

    Returns:
        (tuple[str, None]): Path of the `*_web_model` directory and None.

    Raises:
        subprocess.CalledProcessError: If `tensorflowjs_converter` exits with a non-zero status.
    """
    check_requirements("tensorflowjs")
    import tensorflow as tf
    import tensorflowjs as tfjs  # noqa

    LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
    f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
    f_pb = str(self.file.with_suffix(".pb"))  # *.pb path

    gd = tf.Graph().as_graph_def()  # TF GraphDef
    with open(f_pb, "rb") as file:
        gd.ParseFromString(file.read())
    outputs = ",".join(gd_outputs(gd))
    LOGGER.info(f"\n{prefix} output node names: {outputs}")

    # half takes precedence over int8; no flag when neither is requested
    if self.args.half:
        quantization = "--quantize_float16"
    elif self.args.int8:
        quantization = "--quantize_uint8"
    else:
        quantization = ""

    with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
        cmd = (
            "tensorflowjs_converter "
            f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
        )
        LOGGER.info(f"{prefix} running '{cmd}'")
        # check=True: a failed conversion previously went unnoticed and metadata was still written,
        # so the export reported success for a directory without a usable model.
        subprocess.run(cmd, shell=True, check=True)

    if " " in f:
        LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{f}'.")

    # Add metadata
    YAML.save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
    return f, None
|
1102
|
+
|
1103
|
+
@try_export
def export_rknn(self, prefix=colorstr("RKNN:")):
    """
    Export the model to RKNN format by way of an intermediate ONNX export.

    The target platform is taken from `self.args.name`. The `.rknn` file and a `metadata.yaml` are written using
    a `<onnx stem>_rknn_model` directory, which is created relative to the current working directory.

    Args:
        prefix (str): Prefix used for log messages.

    Returns:
        (tuple[Path, None]): The `*_rknn_model` directory and None.
    """
    LOGGER.info(f"\n{prefix} starting export with rknn-toolkit2...")

    check_requirements("rknn-toolkit2")
    if IS_COLAB:
        # Prevent 'exit' from closing the notebook https://github.com/airockchip/rknn-toolkit2/issues/259
        import builtins

        builtins.exit = lambda: None

    from rknn.api import RKNN

    onnx_file, _ = self.export_onnx()
    out_dir = Path(f"{Path(onnx_file).stem}_rknn_model")
    out_dir.mkdir(exist_ok=True)

    platform = self.args.name
    converter = RKNN(verbose=False)
    converter.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], target_platform=platform)
    converter.load_onnx(model=onnx_file)
    converter.build(do_quantization=False)  # TODO: Add quantization support

    rknn_file = onnx_file.replace(".onnx", f"-{platform}.rknn")
    converter.export_rknn(str(out_dir / rknn_file))
    YAML.save(out_dir / "metadata.yaml", self.metadata)
    return out_dir, None
|
1129
|
+
|
1130
|
+
@try_export
def export_imx(self, prefix=colorstr("IMX:")):
    """
    Quantize the model with Model Compression Toolkit and convert it for Sony IMX500 with `imxconv-pt`.

    Steps: verify platform/model support, make sure Java>=17 is installed, run post-training quantization with a
    representative calibration dataset, wrap the quantized model with a multiclass NMS layer, export it to ONNX
    with metadata, run the IMX500 converter and write `labels.txt` into the output directory.

    Args:
        prefix (str): Prefix used for log messages.

    Returns:
        (tuple[Path, None]): The `*_imx_model` output directory and None.

    Raises:
        AssertionError: If the host is not Linux.
        ValueError: If the model is end2end or its module count does not match the supported YOLOv8n/YOLO11n.
        subprocess.CalledProcessError: If installing Java or running `imxconv-pt` fails.
    """
    gptq = False  # hard-coded switch: False selects plain PTQ, True would select gradient-based PTQ
    assert LINUX, (
        "export only supported on Linux. "
        "See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
    )
    if getattr(self.model, "end2end", False):
        raise ValueError("IMX export is not supported for end2end models.")
    check_requirements(("model-compression-toolkit>=2.3.0", "sony-custom-layers>=0.3.0", "edge-mdt-tpc>=1.1.0"))
    check_requirements("imx500-converter[pt]>=3.16.1")  # Separate requirements for imx500-converter

    import model_compression_toolkit as mct
    import onnx
    from edgemdt_tpc import get_target_platform_capabilities
    from sony_custom_layers.pytorch import multiclass_nms

    LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")

    # Install Java>=17
    # The version gate is an explicit comparison rather than an `assert` caught as AssertionError: asserts are
    # stripped under `python -O`, which would silently skip the install for an outdated Java.
    try:
        java_output = subprocess.run(["java", "--version"], check=True, capture_output=True).stdout.decode()
        version_match = re.search(r"(?:openjdk|java) (\d+)", java_output)
        java_version = int(version_match.group(1)) if version_match else 0
    except (FileNotFoundError, subprocess.CalledProcessError):
        java_version = 0  # java missing or not runnable
    if java_version < 17:
        cmd = (["sudo"] if is_sudo_available() else []) + ["apt", "install", "-y", "openjdk-21-jre"]
        subprocess.run(cmd, check=True)

    # The dataloader is bound as a default argument so it is built once, at definition time
    def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
        for batch in dataloader:
            img = batch["img"]
            img = img / 255.0
            yield [img]

    tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")

    bit_cfg = mct.core.BitWidthConfig()
    if "C2PSA" in str(self.model):  # YOLO11
        layer_names = ["sub", "mul_2", "add_14", "cat_21"]
        weights_memory = 2585350.2439
        n_layers = 238  # 238 layers for fused YOLO11n
    else:  # YOLOv8
        layer_names = ["sub", "mul", "add_6", "cat_17"]
        weights_memory = 2550540.8
        n_layers = 168  # 168 layers for fused YOLOv8n

    # Check if the model has the expected number of layers
    if len(list(self.model.modules())) != n_layers:
        raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")

    # Use 16-bit activations for the listed nodes
    for layer_name in layer_names:
        bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)

    config = mct.core.CoreConfig(
        mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
        quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
        bit_width_config=bit_cfg,
    )

    resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)

    quant_model = (
        mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
            model=self.model,
            representative_data_gen=representative_dataset_gen,
            target_resource_utilization=resource_utilization,
            gptq_config=mct.gptq.get_pytorch_gptq_config(
                n_epochs=1000, use_hessian_based_weights=False, use_hessian_sample_attention=False
            ),
            core_config=config,
            target_platform_capabilities=tpc,
        )[0]
        if gptq
        else mct.ptq.pytorch_post_training_quantization(  # Perform post training quantization
            in_module=self.model,
            representative_data_gen=representative_dataset_gen,
            target_resource_utilization=resource_utilization,
            core_config=config,
            target_platform_capabilities=tpc,
        )[0]
    )

    class NMSWrapper(torch.nn.Module):
        def __init__(
            self,
            model: torch.nn.Module,
            score_threshold: float = 0.001,
            iou_threshold: float = 0.7,
            max_detections: int = 300,
        ):
            """
            Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.

            Args:
                model (nn.Module): Model instance.
                score_threshold (float): Score threshold for non-maximum suppression.
                iou_threshold (float): Intersection over union threshold for non-maximum suppression.
                max_detections (int): The number of detections to return.
            """
            super().__init__()
            self.model = model
            self.score_threshold = score_threshold
            self.iou_threshold = iou_threshold
            self.max_detections = max_detections

        def forward(self, images):
            """Run the wrapped model and apply multiclass NMS to its first two outputs (boxes, scores)."""
            # model inference
            outputs = self.model(images)

            boxes = outputs[0]
            scores = outputs[1]
            nms = multiclass_nms(
                boxes=boxes,
                scores=scores,
                score_threshold=self.score_threshold,
                iou_threshold=self.iou_threshold,
                max_detections=self.max_detections,
            )
            return nms

    quant_model = NMSWrapper(
        model=quant_model,
        score_threshold=self.args.conf or 0.001,
        iou_threshold=self.args.iou,
        max_detections=self.args.max_det,
    ).to(self.device)

    f = Path(str(self.file).replace(self.file.suffix, "_imx_model"))  # output directory
    f.mkdir(exist_ok=True)
    onnx_model = f / Path(str(self.file.name).replace(self.file.suffix, "_imx.onnx"))  # quantized ONNX path
    mct.exporter.pytorch_export_model(
        model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
    )

    model_onnx = onnx.load(onnx_model)  # load onnx model
    for k, v in self.metadata.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)

    onnx.save(model_onnx, onnx_model)

    subprocess.run(
        ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
        check=True,
    )

    # Needed for imx models.
    with open(f / "labels.txt", "w", encoding="utf-8") as file:
        file.writelines([f"{name}\n" for name in self.model.names.values()])

    return f, None
|
1283
|
+
|
1284
|
+
def _add_tflite_metadata(self, file):
|
1285
|
+
"""Add metadata to *.tflite models per https://ai.google.dev/edge/litert/models/metadata."""
|
1286
|
+
import zipfile
|
1287
|
+
|
1288
|
+
with zipfile.ZipFile(file, "a", zipfile.ZIP_DEFLATED) as zf:
|
1289
|
+
zf.writestr("metadata.json", json.dumps(self.metadata, indent=2))
|
1290
|
+
|
1291
|
+
def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
    """
    Chain a CoreML detection model with a CoreML non-maximum-suppression model into one pipeline model.

    Args:
        model: CoreML model exposing `get_spec()` whose spec has two outputs (confidence, coordinates).
        weights_dir (str | None): Passed through to `ct.models.MLModel` as `weights_dir`.
        prefix (str): Prefix used for log messages.

    Returns:
        CoreML pipeline model with inputs `image`, `iouThreshold`, `confidenceThreshold` and outputs
        `confidence`, `coordinates`.

    Raises:
        AssertionError: If the number of class names differs from the class dimension of the confidence output.
    """
    import coremltools as ct  # noqa

    LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
    _, _, img_h, img_w = self.im.shape  # BCHW

    # Output shapes
    spec = model.get_spec()
    conf_feat, coord_feat = spec.description.output
    if MACOS:
        from PIL import Image

        blank = Image.new("RGB", (img_w, img_h))  # e.g. w=192, h=320
        prediction = model.predict({"image": blank})
        conf_shape = prediction[conf_feat.name].shape  # e.g. (3780, 80)
        coord_shape = prediction[coord_feat.name].shape  # e.g. (3780, 4)
    else:  # linux and windows can not run model.predict(), get sizes from PyTorch model output y
        conf_shape = self.output_shape[2], self.output_shape[1] - 4  # e.g. (3780, 80)
        coord_shape = self.output_shape[2], 4  # e.g. (3780, 4)

    # Checks
    names = self.metadata["names"]
    image_type = spec.description.input[0].type.imageType
    in_w, in_h = image_type.width, image_type.height
    _, nc = conf_shape  # number of anchors, number of classes
    assert len(names) == nc, f"{len(names)} names found for nc={nc}"  # check

    # Define output shapes (missing from the spec)
    conf_feat.type.multiArrayType.shape[:] = conf_shape
    coord_feat.type.multiArrayType.shape[:] = coord_shape

    # Rebuild the detection model from the updated spec
    model = ct.models.MLModel(spec, weights_dir=weights_dir)

    # Create NMS protobuf: its inputs and outputs start as copies of the detection model outputs
    nms_spec = ct.proto.Model_pb2.Model()
    nms_spec.specificationVersion = spec.specificationVersion
    nms_desc = nms_spec.description
    for i in range(2):
        serialized = model._spec.description.output[i].SerializeToString()
        nms_desc.input.add().ParseFromString(serialized)
        nms_desc.output.add().ParseFromString(serialized)

    nms_desc.output[0].name = "confidence"
    nms_desc.output[1].name = "coordinates"

    # Replace the fixed output shapes with ranges: unbounded first dimension, fixed second dimension
    for i, fixed_size in enumerate((nc, 4)):
        ma_type = nms_desc.output[i].type.multiArrayType
        size_ranges = ma_type.shapeRange.sizeRanges
        size_ranges.add()
        size_ranges[0].lowerBound = 0
        size_ranges[0].upperBound = -1
        size_ranges.add()
        size_ranges[1].lowerBound = fixed_size
        size_ranges[1].upperBound = fixed_size
        del ma_type.shape[:]

    nms = nms_spec.nonMaximumSuppression
    nms.confidenceInputFeatureName = conf_feat.name
    nms.coordinatesInputFeatureName = coord_feat.name
    nms.confidenceOutputFeatureName = "confidence"
    nms.coordinatesOutputFeatureName = "coordinates"
    nms.iouThresholdInputFeatureName = "iouThreshold"
    nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
    nms.iouThreshold = self.args.iou
    nms.confidenceThreshold = self.args.conf
    nms.pickTop.perClass = True
    nms.stringClassLabels.vector.extend(names.values())
    nms_model = ct.models.MLModel(nms_spec)

    # Pipeline models together
    array_t, double_t = ct.models.datatypes.Array, ct.models.datatypes.Double
    pipeline = ct.models.pipeline.Pipeline(
        input_features=[
            ("image", array_t(3, in_h, in_w)),
            ("iouThreshold", double_t()),
            ("confidenceThreshold", double_t()),
        ],
        output_features=["confidence", "coordinates"],
    )
    pipeline.add_model(model)
    pipeline.add_model(nms_model)

    # Correct datatypes by copying the feature descriptions from the component models
    pipe_desc = pipeline.spec.description
    pipe_desc.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
    pipe_desc.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
    pipe_desc.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())

    # Update metadata
    pipeline.spec.specificationVersion = spec.specificationVersion
    pipe_desc.metadata.userDefined.update(
        {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
    )

    # Build the final model and describe its features
    model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
    model.input_description["image"] = "Input image"
    model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})"
    model.input_description["confidenceThreshold"] = (
        f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
    )
    model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
    model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
    LOGGER.info(f"{prefix} pipeline success")
    return model
|
1396
|
+
|
1397
|
+
def add_callback(self, event: str, callback):
    """Register `callback` at the end of the callback list stored for `event`."""
    handlers = self.callbacks[event]
    handlers.append(callback)
|
1400
|
+
|
1401
|
+
def run_callbacks(self, event: str):
    """Call every callback registered for `event`, in order, passing this object; unknown events are a no-op."""
    handlers = self.callbacks.get(event, ())
    for handler in handlers:
        handler(self)
|
1405
|
+
|
1406
|
+
|
1407
|
+
class IOSDetectModel(torch.nn.Module):
    """Wrapper around an Ultralytics YOLO detection model used for Apple iOS CoreML export."""

    def __init__(self, model, im):
        """
        Store the model and derive the coordinate normalization factor from the example input.

        Args:
            model: Detection model exposing a `names` collection (one entry per class).
            im (torch.Tensor): Example input whose shape unpacks as (batch, channel, height, width).
        """
        super().__init__()
        _batch, _channels, height, width = im.shape
        self.model = model
        self.nc = len(model.names)  # number of classes
        # A scalar suffices for square inputs; otherwise scale x/w by width and y/h by height
        self.normalize = (
            1.0 / width
            if width == height
            else torch.tensor([1.0 / width, 1.0 / height, 1.0 / width, 1.0 / height])
        )

    def forward(self, x):
        """Return (class scores, normalized xywh boxes) taken from the first output of the wrapped model."""
        prediction = self.model(x)[0].transpose(0, 1)
        xywh, cls = prediction.split((4, self.nc), 1)
        return cls, xywh * self.normalize
|
1425
|
+
|
1426
|
+
|
1427
|
+
class NMSModel(torch.nn.Module):
    """Wrap a model so that NMS runs inside `forward`; handles Detect, Segment, Pose and OBB outputs."""

    def __init__(self, model, args):
        """
        Set up the NMS wrapper.

        Args:
            model (torch.nn.Module): Model whose predictions are post-processed; must expose `task` and `names`.
            args (Namespace): Export arguments (`format`, `conf`, `iou`, `max_det`, `agnostic_nms`, ...).
        """
        super().__init__()
        self.model = model
        self.args = args
        self.obb = model.task == "obb"
        self.is_tf = args.format in {"saved_model", "tflite", "tfjs"}

    def forward(self, x):
        """
        Run the wrapped model and apply per-image NMS to its predictions.

        Args:
            x (torch.Tensor): Preprocessed input batch; `x.shape[2:]` is used as the image size.

        Returns:
            (torch.Tensor | tuple): Tensor of shape (batch, max_det, 4 + 2 + extra_shape) holding box, score,
                class and any extra values per kept detection, zero-padded to `max_det` rows. For the
                "segment" task a tuple of that tensor and the second model output is returned.
        """
        from functools import partial

        from torchvision.ops import nms

        args = self.args
        preds = self.model(x)
        pred = preds[0] if isinstance(preds, tuple) else preds
        kwargs = dict(device=pred.device, dtype=pred.dtype)
        bs = pred.shape[0]
        pred = pred.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
        nc = len(self.model.names)
        extra_shape = pred.shape[-1] - (4 + nc)  # extras from Segment, OBB, Pose
        if args.dynamic and args.batch > 1:  # batch size needs to always be same due to loop unroll
            pad_rows = torch.max(torch.tensor(args.batch - bs), torch.tensor(0))
            pred = torch.cat((pred, torch.zeros(pad_rows, *pred.shape[1:], **kwargs)))
        boxes, scores, extras = pred.split([4, nc, extra_shape], dim=2)
        scores, classes = scores.max(dim=-1)
        args.max_det = min(pred.shape[1], args.max_det)  # in case num_anchors < max_det
        # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
        out = torch.zeros(bs, args.max_det, boxes.shape[-1] + 2 + extra_shape, **kwargs)

        # Loop-invariant, non-tensor settings
        # `8` is the minimum value experimented to get correct NMS results for obb
        multiplier = 8 if self.obb else 1
        end = 2 if self.obb else 4  # number of leading box columns that receive the class offset
        if self.obb:
            avoid_triu = (
                self.is_tf
                or (args.opset or 14) < 14
                or (args.format == "openvino" and args.int8)  # OpenVINO int8 error with triu
            )
            nms_fn = partial(nms_rotated, use_triu=not avoid_triu)
        else:
            nms_fn = nms

        for i in range(bs):
            box, cls, score, extra = boxes[i], classes[i], scores[i], extras[i]
            mask = score > args.conf
            if self.is_tf:
                # TFLite GatherND error if mask is empty
                score *= mask
                # Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`
                mask = score.topk(min(args.max_det * 5, score.shape[0])).indices
            box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
            nmsbox = box.clone()
            # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
            if args.format == "tflite":  # TFLite is already normalized
                nmsbox *= multiplier
            else:
                nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], **kwargs).max()
            if not args.agnostic_nms:  # class-specific NMS
                # fully explicit expansion otherwise reshape error
                # large max_wh causes issues when quantizing
                cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
                offbox = nmsbox[:, :end] + cls_offset * multiplier
                nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
            nms_input = torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox
            keep = nms_fn(nms_input, score, args.iou)[: args.max_det]
            dets = torch.cat(
                [box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1).to(out.dtype), extra[keep]], dim=-1
            )
            # Zero-pad to max_det size to avoid reshape error
            out[i] = torch.nn.functional.pad(dets, (0, 0, 0, args.max_det - dets.shape[0]))

        detections = out[:bs]
        return (detections, preds[1]) if self.model.task == "segment" else detections
|