ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +527 -67
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +44 -37
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +84 -56
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.28.dist-info/METADATA +0 -373
- ultralytics-8.1.28.dist-info/RECORD +0 -197
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/engine/exporter.py
CHANGED
@@ -1,55 +1,60 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
-Export a
+Export a YOLO PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit.
 
 Format | `format=argument` | Model
 --- | --- | ---
-PyTorch | - |
-TorchScript | `torchscript` |
-ONNX | `onnx` |
-OpenVINO | `openvino` |
-TensorRT | `engine` |
-CoreML | `coreml` |
-TensorFlow SavedModel | `saved_model` |
-TensorFlow GraphDef | `pb` |
-TensorFlow Lite | `tflite` |
-TensorFlow Edge TPU | `edgetpu` |
-TensorFlow.js | `tfjs` |
-PaddlePaddle | `paddle` |
-
+PyTorch | - | yolo11n.pt
+TorchScript | `torchscript` | yolo11n.torchscript
+ONNX | `onnx` | yolo11n.onnx
+OpenVINO | `openvino` | yolo11n_openvino_model/
+TensorRT | `engine` | yolo11n.engine
+CoreML | `coreml` | yolo11n.mlpackage
+TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/
+TensorFlow GraphDef | `pb` | yolo11n.pb
+TensorFlow Lite | `tflite` | yolo11n.tflite
+TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
+TensorFlow.js | `tfjs` | yolo11n_web_model/
+PaddlePaddle | `paddle` | yolo11n_paddle_model/
+MNN | `mnn` | yolo11n.mnn
+NCNN | `ncnn` | yolo11n_ncnn_model/
+IMX | `imx` | yolo11n_imx_model/
 
 Requirements:
     $ pip install "ultralytics[export]"
 
 Python:
     from ultralytics import YOLO
-    model = YOLO('
+    model = YOLO('yolo11n.pt')
     results = model.export(format='onnx')
 
 CLI:
-    $ yolo mode=export model=
+    $ yolo mode=export model=yolo11n.pt format=onnx
 
 Inference:
-    $ yolo predict model=
-
-
-
-
-
-
-
-
-
-
-
+    $ yolo predict model=yolo11n.pt             # PyTorch
+                         yolo11n.torchscript    # TorchScript
+                         yolo11n.onnx           # ONNX Runtime or OpenCV DNN with dnn=True
+                         yolo11n_openvino_model # OpenVINO
+                         yolo11n.engine         # TensorRT
+                         yolo11n.mlpackage      # CoreML (macOS-only)
+                         yolo11n_saved_model    # TensorFlow SavedModel
+                         yolo11n.pb             # TensorFlow GraphDef
+                         yolo11n.tflite         # TensorFlow Lite
+                         yolo11n_edgetpu.tflite # TensorFlow Edge TPU
+                         yolo11n_paddle_model   # PaddlePaddle
+                         yolo11n.mnn            # MNN
+                         yolo11n_ncnn_model     # NCNN
+                         yolo11n_imx_model      # IMX
 
 TensorFlow.js:
     $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
     $ npm install
-    $ ln -s ../../
+    $ ln -s ../../yolo11n_web_model public/yolo11n_web_model
    $ npm start
 """
 
+import gc
 import json
 import os
 import shutil
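For orientation, the Python and CLI usage shown in this docstring boils down to the following minimal sketch (model file and target format are illustrative, not part of the diff):

```python
from ultralytics import YOLO

# Any 'format=' value from the Argument column above is accepted,
# including the newly added 'mnn', 'ncnn' and 'imx' targets
model = YOLO("yolo11n.pt")
path = model.export(format="onnx")  # path of the exported file or directory
```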
@@ -63,18 +68,21 @@ from pathlib import Path
 import numpy as np
 import torch
 
-from ultralytics.cfg import get_cfg
+from ultralytics.cfg import TASK2DATA, get_cfg
+from ultralytics.data import build_dataloader
 from ultralytics.data.dataset import YOLODataset
-from ultralytics.data.utils import check_det_dataset
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.autobackend import check_class_names, default_class_names
-from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
+from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
 from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
 from ultralytics.utils import (
     ARM64,
     DEFAULT_CFG,
+    IS_JETSON,
     LINUX,
     LOGGER,
     MACOS,
+    PYTHON_VERSION,
     ROOT,
     WINDOWS,
     __version__,
@@ -83,33 +91,57 @@ from ultralytics.utils import (
     get_default_args,
     yaml_save,
 )
-from ultralytics.utils.checks import
-from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
+from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
+from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
 from ultralytics.utils.files import file_size, spaces_in_path
 from ultralytics.utils.ops import Profile
-from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
+from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
 
 
 def export_formats():
-    """
-    import pandas
-
+    """Ultralytics YOLO export formats."""
     x = [
-        ["PyTorch", "-", ".pt", True, True],
-        ["TorchScript", "torchscript", ".torchscript", True, True],
-        ["ONNX", "onnx", ".onnx", True, True],
-        ["OpenVINO", "openvino", "_openvino_model", True, False],
-        ["TensorRT", "engine", ".engine", False, True],
-        ["CoreML", "coreml", ".mlpackage", True, False],
-        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
-        ["TensorFlow GraphDef", "pb", ".pb", True, True],
-        ["TensorFlow Lite", "tflite", ".tflite", True, False],
-        ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
-        ["TensorFlow.js", "tfjs", "_web_model", True, False],
-        ["PaddlePaddle", "paddle", "_paddle_model", True, True],
-        ["
+        ["PyTorch", "-", ".pt", True, True, []],
+        ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]],
+        ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]],
+        ["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]],
+        ["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]],
+        ["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]],
+        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]],
+        ["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
+        ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]],
+        ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
+        ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]],
+        ["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
+        ["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
+        ["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
+        ["IMX", "imx", "_imx_model", True, True, ["int8"]],
     ]
-    return
+    return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x)))
+
+
+def validate_args(format, passed_args, valid_args):
+    """
+    Validates arguments based on format.
+
+    Args:
+        format (str): The export format.
+        passed_args (Namespace): The arguments used during export.
+        valid_args (dict): List of valid arguments for the format.
+
+    Raises:
+        AssertionError: If an argument that's not supported by the export format is used, or if format doesn't have the supported arguments listed.
+    """
+    # Only check valid usage of these args
+    export_args = ["half", "int8", "dynamic", "keras", "nms", "batch"]
+
+    assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed."
+    custom = {"batch": 1, "data": None, "device": None}  # exporter defaults
+    default_args = get_cfg(DEFAULT_CFG, custom)
+    for arg in export_args:
+        not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None)
+        if not_default:
+            assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'"
 
 
 def gd_outputs(gd):
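`export_formats()` now returns a plain dict of column-name to per-format tuples instead of the old `pandas.DataFrame` (note the dropped `import pandas`). A small consumption sketch, assuming only what the new return statement shows:

```python
from ultralytics.engine.exporter import export_formats

fmts = export_formats()  # keys: "Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"
print(fmts["Argument"])  # ('-', 'torchscript', 'onnx', ..., 'mnn', 'ncnn', 'imx')

# Per-format rows can be rebuilt by zipping the column tuples back together
for name, arg, suffix in zip(fmts["Format"], fmts["Argument"], fmts["Suffix"]):
    print(f"{name}: format={arg} -> {suffix}")
```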
@@ -122,7 +154,7 @@ def gd_outputs(gd):
 
 
 def try_export(inner_func):
-    """
+    """YOLO export decorator, i.e. @try_export."""
     inner_args = get_default_args(inner_func)
 
     def outer_func(*args, **kwargs):
@@ -134,7 +166,7 @@ def try_export(inner_func):
             LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
             return f, model
         except Exception as e:
-            LOGGER.
+            LOGGER.error(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
             raise e
 
     return outer_func
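The decorator's control flow is unchanged apart from the failure branch now using `LOGGER.error`; a simplified standalone sketch (plain `time`/`print` stand in for the real `Profile` and `LOGGER`):

```python
import time
from functools import wraps


def try_export_sketch(inner_func):
    """Simplified stand-in for @try_export: log success or failure, then re-raise on error."""

    @wraps(inner_func)
    def outer_func(*args, **kwargs):
        t0 = time.time()
        try:
            f, model = inner_func(*args, **kwargs)  # export methods return (path, model)
            print(f"export success ✅ {time.time() - t0:.1f}s, saved as '{f}'")
            return f, model
        except Exception as e:
            print(f"export failure ❌ {time.time() - t0:.1f}s: {e}")  # LOGGER.error upstream
            raise e

    return outer_func
```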
@@ -159,48 +191,94 @@ class Exporter:
             _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
-        if self.args.format.lower() in
+        if self.args.format.lower() in {"coreml", "mlmodel"}:  # fix attempt for protobuf<3.20.x errors
             os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"  # must run before TensorBoard callback
 
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
         callbacks.add_integration_callbacks(self)
 
-
-    def __call__(self, model=None):
+    def __call__(self, model=None) -> str:
         """Returns list of exported files/dirs after running callbacks."""
         self.run_callbacks("on_export_start")
         t = time.time()
         fmt = self.args.format.lower()  # to lowercase
-        if fmt in
+        if fmt in {"tensorrt", "trt"}:  # 'engine' aliases
             fmt = "engine"
-        if fmt in
+        if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}:  # 'coreml' aliases
             fmt = "coreml"
-
+        fmts_dict = export_formats()
+        fmts = tuple(fmts_dict["Argument"][1:])  # available export formats
+        if fmt not in fmts:
+            import difflib
+
+            # Get the closest match if format is invalid
+            matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6)  # 60% similarity required to match
+            if not matches:
+                raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
+            LOGGER.warning(f"WARNING ⚠️ Invalid export format='{fmt}', updating to format='{matches[0]}'")
+            fmt = matches[0]
         flags = [x == fmt for x in fmts]
         if sum(flags) != 1:
             raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
-
+        (
+            jit,
+            onnx,
+            xml,
+            engine,
+            coreml,
+            saved_model,
+            pb,
+            tflite,
+            edgetpu,
+            tfjs,
+            paddle,
+            mnn,
+            ncnn,
+            imx,
+        ) = flags  # export booleans
+        is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
 
         # Device
+        dla = None
         if fmt == "engine" and self.args.device is None:
             LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
             self.args.device = "0"
+        if fmt == "engine" and "dla" in str(self.args.device):  # convert int/list to str first
+            dla = self.args.device.split(":")[-1]
+            self.args.device = "0"  # update device to "0"
+            assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
         self.device = select_device("cpu" if self.args.device is None else self.args.device)
 
-        #
+        # Argument compatibility checks
+        fmt_keys = fmts_dict["Arguments"][flags.index(True) + 1]
+        validate_args(fmt, self.args, fmt_keys)
+        if imx and not self.args.int8:
+            LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.")
+            self.args.int8 = True
         if not hasattr(model, "names"):
             model.names = default_class_names()
         model.names = check_class_names(model.names)
+        if self.args.half and self.args.int8:
+            LOGGER.warning("WARNING ⚠️ half=True and int8=True are mutually exclusive, setting half=False.")
+            self.args.half = False
         if self.args.half and onnx and self.device.type == "cpu":
             LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0")
             self.args.half = False
             assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one."
         self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
+        if self.args.int8 and engine:
+            self.args.dynamic = True  # enforce dynamic to export TensorRT INT8
         if self.args.optimize:
             assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
             assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
-        if
-
+        if self.args.int8 and tflite:
+            assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
+        if edgetpu:
+            if not LINUX:
+                raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
+            elif self.args.batch != 1:  # see github.com/ultralytics/ultralytics/pull/13420
+                LOGGER.warning("WARNING ⚠️ Edge TPU export requires batch size 1, setting batch=1.")
+                self.args.batch = 1
         if isinstance(model, WorldModel):
             LOGGER.warning(
                 "WARNING ⚠️ YOLOWorld (original version) export is not supported to any format.\n"
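Two behaviors added in this hunk, shown as hedged usage sketches (the Jetson hardware and model file are assumptions, not part of the diff):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")

# A near-miss format string is now fuzzy-matched via difflib (60% cutoff)
# instead of failing outright; this proceeds as format='onnx' with a warning
model.export(format="onxx")

# TensorRT export can target a DLA core on NVIDIA Jetson devices;
# 'half=True' or 'int8=True' is required, and the core must be 0 or 1
model.export(format="engine", device="dla:0", half=True)
```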
@@ -208,6 +286,13 @@ class Exporter:
                 "(torchscript, onnx, openvino, engine, coreml) formats. "
                 "See https://docs.ultralytics.com/models/yolo-world for details."
             )
+            model.clip_model = None  # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
+        if self.args.int8 and not self.args.data:
+            self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")]  # assign default data
+            LOGGER.warning(
+                "WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. "
+                f"Using default 'data={self.args.data}'."
+            )
 
         # Input
         im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
@@ -224,14 +309,31 @@ class Exporter:
         model.eval()
         model.float()
         model = model.fuse()
+
+        if imx:
+            from ultralytics.utils.torch_utils import FXModel
+
+            model = FXModel(model)
         for m in model.modules():
+            if isinstance(m, Classify):
+                m.export = True
             if isinstance(m, (Detect, RTDETRDecoder)):  # includes all Detect subclasses like Segment, Pose, OBB
                 m.dynamic = self.args.dynamic
                 m.export = True
                 m.format = self.args.format
-
+                m.max_det = self.args.max_det
+            elif isinstance(m, C2f) and not is_tf_format:
                 # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
                 m.forward = m.forward_split
+            if isinstance(m, Detect) and imx:
+                from ultralytics.utils.tal import make_anchors
+
+                m.anchors, m.strides = (
+                    x.transpose(0, 1)
+                    for x in make_anchors(
+                        torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
+                    )
+                )
 
         y = None
         for _ in range(2):
@@ -255,7 +357,7 @@ class Exporter:
         )
         self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
         data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
-        description = f
+        description = f"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}"
         self.metadata = {
             "description": description,
             "author": "Ultralytics",
@@ -268,13 +370,14 @@ class Exporter:
             "batch": self.args.batch,
             "imgsz": self.imgsz,
             "names": model.names,
+            "args": {k: v for k, v in self.args if k in fmt_keys},
         }  # model metadata
         if model.task == "pose":
             self.metadata["kpt_shape"] = model.model[-1].kpt_shape
 
         LOGGER.info(
             f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
-            f
+            f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)"
         )
 
         # Exports
@@ -282,14 +385,14 @@ class Exporter:
         if jit or ncnn:  # TorchScript
             f[0], _ = self.export_torchscript()
         if engine:  # TensorRT required before ONNX
-            f[1], _ = self.export_engine()
+            f[1], _ = self.export_engine(dla=dla)
         if onnx:  # ONNX
             f[2], _ = self.export_onnx()
         if xml:  # OpenVINO
             f[3], _ = self.export_openvino()
         if coreml:  # CoreML
             f[4], _ = self.export_coreml()
-        if
+        if is_tf_format:  # TensorFlow formats
             self.args.int8 |= edgetpu
             f[5], keras_model = self.export_saved_model()
             if pb or tfjs:  # pb prerequisite to tfjs
@@ -302,8 +405,12 @@ class Exporter:
                 f[9], _ = self.export_tfjs()
         if paddle:  # PaddlePaddle
             f[10], _ = self.export_paddle()
+        if mnn:  # MNN
+            f[11], _ = self.export_mnn()
         if ncnn:  # NCNN
-            f[
+            f[12], _ = self.export_ncnn()
+        if imx:
+            f[13], _ = self.export_imx()
 
         # Finish
         f = [str(x) for x in f if x]  # filter out '' and None
@@ -320,19 +427,42 @@ class Exporter:
             predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
             q = "int8" if self.args.int8 else "half" if self.args.half else ""  # quantization
             LOGGER.info(
-                f
+                f"\nExport complete ({time.time() - t:.1f}s)"
                 f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
-                f
-                f
-                f
+                f"\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}"
+                f"\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}"
+                f"\nVisualize: https://netron.app"
             )
 
         self.run_callbacks("on_export_end")
         return f  # return list of exported files/dirs
 
+    def get_int8_calibration_dataloader(self, prefix=""):
+        """Build and return a dataloader suitable for calibration of INT8 models."""
+        LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
+        data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)
+        # TensorRT INT8 calibration should use 2x batch size
+        batch = self.args.batch * (2 if self.args.format == "engine" else 1)
+        dataset = YOLODataset(
+            data[self.args.split or "val"],
+            data=data,
+            task=self.model.task,
+            imgsz=self.imgsz[0],
+            augment=False,
+            batch_size=batch,
+        )
+        n = len(dataset)
+        if n < self.args.batch:
+            raise ValueError(
+                f"The calibration dataset ({n} images) must have at least as many images as the batch size ('batch={self.args.batch}')."
+            )
+        elif n < 300:
+            LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
+        return build_dataloader(dataset, batch=batch, workers=0)  # required for batch loading
+
     @try_export
     def export_torchscript(self, prefix=colorstr("TorchScript:")):
-        """
+        """YOLO TorchScript model export."""
         LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
         f = self.file.with_suffix(".torchscript")
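INT8 calibration is now centralized in `get_int8_calibration_dataloader()`, shared by the OpenVINO, TensorRT and TFLite paths. A usage sketch, assuming the bundled `coco8.yaml` dataset:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")

# 'data' supplies calibration images; >=300 are recommended, and at least 'batch' are required
model.export(format="openvino", int8=True, data="coco8.yaml")

# For TensorRT, dynamic=True is enforced for INT8 and calibration runs at 2x the export batch size
model.export(format="engine", int8=True, data="coco8.yaml", batch=8)
```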
@@ -349,12 +479,10 @@ class Exporter:
 
     @try_export
     def export_onnx(self, prefix=colorstr("ONNX:")):
-        """
+        """YOLO ONNX export."""
         requirements = ["onnx>=1.12.0"]
         if self.args.simplify:
-            requirements += ["
-            if ARM64:
-                check_requirements("cmake")  # 'cmake' is needed to build onnxsim on aarch64
+            requirements += ["onnxslim", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
         check_requirements(requirements)
         import onnx  # noqa
 
@@ -386,19 +514,17 @@ class Exporter:
 
         # Checks
         model_onnx = onnx.load(f)  # load onnx model
-        # onnx.checker.check_model(model_onnx)  # check onnx model
 
         # Simplify
         if self.args.simplify:
             try:
-                import
+                import onnxslim
+
+                LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...")
+                model_onnx = onnxslim.slim(model_onnx)
 
-                LOGGER.info(f"{prefix} simplifying with onnxsim {onnxsim.__version__}...")
-                # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
-                model_onnx, check = onnxsim.simplify(model_onnx)
-                assert check, "Simplified ONNX model could not be validated"
             except Exception as e:
-                LOGGER.
+                LOGGER.warning(f"{prefix} simplifier failure: {e}")
 
         # Metadata
         for k, v in self.metadata.items():
@@ -410,21 +536,21 @@ class Exporter:
 
     @try_export
     def export_openvino(self, prefix=colorstr("OpenVINO:")):
-        """
-        check_requirements("openvino>=2024.
+        """YOLO OpenVINO export."""
+        check_requirements("openvino>=2024.5.0")
         import openvino as ov
 
         LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
         assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
         ov_model = ov.convert_model(
-            self.model
+            self.model,
             input=None if self.args.dynamic else [self.im.shape],
             example_input=self.im,
         )
 
         def serialize(ov_model, file):
             """Set RT info, serialize and save metadata YAML."""
-            ov_model.set_rt_info("
+            ov_model.set_rt_info("YOLO", ["model_info", "model_type"])
             ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
             ov_model.set_rt_info(114, ["model_info", "pad_value"])
             ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
@@ -439,37 +565,21 @@ class Exporter:
         if self.args.int8:
             fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
             fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
-
-                self.args.data = DEFAULT_CFG.data or "coco128.yaml"
-                LOGGER.warning(
-                    f"{prefix} WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. "
-                    f"Using default 'data={self.args.data}'."
-                )
-            check_requirements("nncf>=2.8.0")
+            check_requirements("nncf>=2.14.0")
             import nncf
 
-            def transform_fn(data_item):
+            def transform_fn(data_item) -> np.ndarray:
                 """Quantization transform function."""
-
-
-
-                im = data_item["img"].numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
+                data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
+                assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
+                im = data_item.numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
                 return np.expand_dims(im, 0) if im.ndim == 3 else im
 
             # Generate calibration data for integer quantization
-            LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
-            data = check_det_dataset(self.args.data)
-            dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
-            n = len(dataset)
-            if n < 300:
-                LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
-            quantization_dataset = nncf.Dataset(dataset, transform_fn)
-
             ignored_scope = None
             if isinstance(self.model.model[-1], Detect):
                 # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
                 head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
-
                 ignored_scope = nncf.IgnoredScope(  # ignore operations
                     patterns=[
                         f".*{head_module_name}/.*/Add",
@@ -482,7 +592,10 @@ class Exporter:
                 )
 
             quantized_ov_model = nncf.quantize(
-                ov_model,
+                model=ov_model,
+                calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn),
+                preset=nncf.QuantizationPreset.MIXED,
+                ignored_scope=ignored_scope,
             )
             serialize(quantized_ov_model, fq_ov)
             return fq, None
@@ -495,8 +608,8 @@ class Exporter:
 
     @try_export
     def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
-        """
-        check_requirements(("paddlepaddle", "x2paddle"))
+        """YOLO Paddle export."""
+        check_requirements(("paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle", "x2paddle"))
         import x2paddle  # noqa
         from x2paddle.convert import pytorch2paddle  # noqa
 
@@ -507,11 +620,34 @@ class Exporter:
         yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
         return f, None
 
+    @try_export
+    def export_mnn(self, prefix=colorstr("MNN:")):
+        """YOLOv8 MNN export using MNN https://github.com/alibaba/MNN."""
+        f_onnx, _ = self.export_onnx()  # get onnx model first
+
+        check_requirements("MNN>=2.9.6")
+        import MNN  # noqa
+        from MNN.tools import mnnconvert
+
+        # Setup and checks
+        LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...")
+        assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
+        f = str(self.file.with_suffix(".mnn"))  # MNN model file
+        args = ["", "-f", "ONNX", "--modelFile", f_onnx, "--MNNModel", f, "--bizCode", json.dumps(self.metadata)]
+        if self.args.int8:
+            args.extend(("--weightQuantBits", "8"))
+        if self.args.half:
+            args.append("--fp16")
+        mnnconvert.convert(args)
+        # remove scratch file for model convert optimize
+        convert_scratch = Path(self.file.parent / ".__convert_external_data.bin")
+        if convert_scratch.exists():
+            convert_scratch.unlink()
+        return f, None
+
     @try_export
     def export_ncnn(self, prefix=colorstr("NCNN:")):
-        """
-        YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx.
-        """
+        """YOLO NCNN export using PNNX https://github.com/pnnx/pnnx."""
         check_requirements("ncnn")
         import ncnn  # noqa
 
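A usage sketch for the new MNN target (the export flags map directly onto the `mnnconvert` arguments assembled above):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
f = model.export(format="mnn", half=True)  # exports ONNX first, then converts with --fp16
# int8=True would instead pass '--weightQuantBits 8' for weight quantization
```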
@@ -520,7 +656,7 @@ class Exporter:
         f_ts = self.file.with_suffix(".torchscript")
 
         name = Path("pnnx.exe" if WINDOWS else "pnnx")  # PNNX filename
-        pnnx = name if name.is_file() else ROOT / name
+        pnnx = name if name.is_file() else (ROOT / name)
         if not pnnx.is_file():
             LOGGER.warning(
                 f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from "
@@ -528,31 +664,32 @@ class Exporter:
                 f"or in {ROOT}. See PNNX repo for full installation instructions."
             )
             system = "macos" if MACOS else "windows" if WINDOWS else "linux-aarch64" if ARM64 else "linux"
-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                release, assets = get_github_assets(repo="pnnx/pnnx")
+                asset = [x for x in assets if f"{system}.zip" in x][0]
+                assert isinstance(asset, str), "Unable to retrieve PNNX repo assets"  # i.e. pnnx-20240410-macos.zip
+                LOGGER.info(f"{prefix} successfully found latest PNNX asset file {asset}")
+            except Exception as e:
+                release = "20240410"
+                asset = f"pnnx-{release}-{system}.zip"
+                LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {asset}")
+            unzip_dir = safe_download(f"https://github.com/pnnx/pnnx/releases/download/{release}/{asset}", delete=True)
+            if check_is_path_safe(Path.cwd(), unzip_dir):  # avoid path traversal security vulnerability
+                shutil.move(src=unzip_dir / name, dst=pnnx)  # move binary to ROOT
             pnnx.chmod(0o777)  # set read, write, and execute permissions for everyone
+            shutil.rmtree(unzip_dir)  # delete unzip dir
 
         ncnn_args = [
-            f
-            f
-            f
+            f"ncnnparam={f / 'model.ncnn.param'}",
+            f"ncnnbin={f / 'model.ncnn.bin'}",
+            f"ncnnpy={f / 'model_ncnn.py'}",
         ]
 
         pnnx_args = [
-            f
-            f
-            f
-            f
+            f"pnnxparam={f / 'model.pnnx.param'}",
+            f"pnnxbin={f / 'model.pnnx.bin'}",
+            f"pnnxpy={f / 'model_pnnx.py'}",
+            f"pnnxonnx={f / 'model.pnnx.onnx'}",
         ]
 
         cmd = [
@@ -578,16 +715,20 @@ class Exporter:
 
     @try_export
     def export_coreml(self, prefix=colorstr("CoreML:")):
-        """
+        """YOLO CoreML export."""
         mlmodel = self.args.format.lower() == "mlmodel"  # legacy *.mlmodel export format requested
         check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=7.0")
         import coremltools as ct  # noqa
 
         LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
         assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux."
+        assert self.args.batch == 1, "CoreML batch sizes > 1 are not supported. Please retry at 'batch=1'."
         f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
         if f.is_dir():
             shutil.rmtree(f)
+        if self.args.nms and getattr(self.model, "end2end", False):
+            LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
+            self.args.nms = False
 
         bias = [0.0, 0.0, 0.0]
         scale = 1 / 255
@@ -650,40 +791,61 @@ class Exporter:
         return f, ct_model
 
     @try_export
-    def export_engine(self, prefix=colorstr("TensorRT:")):
-        """
+    def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
+        """YOLO TensorRT export https://developer.nvidia.com/tensorrt."""
         assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
-        f_onnx, _ = self.export_onnx()  # run before
+        f_onnx, _ = self.export_onnx()  # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
 
         try:
             import tensorrt as trt  # noqa
         except ImportError:
             if LINUX:
-                check_requirements("
+                check_requirements("tensorrt>7.0.0,!=10.1.0")
             import tensorrt as trt  # noqa
+        check_version(trt.__version__, ">=7.0.0", hard=True)
+        check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
 
-
-
-        self.args.simplify = True
-
+        # Setup and checks
         LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
+        is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # is TensorRT >= 10
         assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
         f = self.file.with_suffix(".engine")  # TensorRT engine file
         logger = trt.Logger(trt.Logger.INFO)
         if self.args.verbose:
             logger.min_severity = trt.Logger.Severity.VERBOSE
 
+        # Engine builder
         builder = trt.Builder(logger)
         config = builder.create_builder_config()
-
-
-
+        workspace = int(self.args.workspace * (1 << 30)) if self.args.workspace is not None else 0
+        if is_trt10 and workspace > 0:
+            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
+        elif workspace > 0:  # TensorRT versions 7, 8
+            config.max_workspace_size = workspace
         flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         network = builder.create_network(flag)
+        half = builder.platform_has_fast_fp16 and self.args.half
+        int8 = builder.platform_has_fast_int8 and self.args.int8
+
+        # Optionally switch to DLA if enabled
+        if dla is not None:
+            if not IS_JETSON:
+                raise ValueError("DLA is only available on NVIDIA Jetson devices")
+            LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
+            if not self.args.half and not self.args.int8:
+                raise ValueError(
+                    "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
+                )
+            config.default_device_type = trt.DeviceType.DLA
+            config.DLA_core = int(dla)
+            config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
+
+        # Read ONNX file
         parser = trt.OnnxParser(network, logger)
         if not parser.parse_from_file(f_onnx):
             raise RuntimeError(f"failed to load ONNX file: {f_onnx}")
 
+        # Network inputs
         inputs = [network.get_input(i) for i in range(network.num_inputs)]
         outputs = [network.get_output(i) for i in range(network.num_outputs)]
         for inp in inputs:
@@ -696,61 +858,117 @@ class Exporter:
             if shape[0] <= 1:
                 LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
             profile = builder.create_optimization_profile()
+            min_shape = (1, shape[1], 32, 32)  # minimum input shape
+            max_shape = (*shape[:2], *(int(max(1, workspace) * d) for d in shape[2:]))  # max input shape
             for inp in inputs:
-                profile.set_shape(inp.name,
+                profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
             config.add_optimization_profile(profile)
 
-        LOGGER.info(
-
-
-
+        LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {f}")
+        if int8:
+            config.set_flag(trt.BuilderFlag.INT8)
+            config.set_calibration_profile(profile)
+            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
+
+            class EngineCalibrator(trt.IInt8Calibrator):
+                def __init__(
+                    self,
+                    dataset,  # ultralytics.data.build.InfiniteDataLoader
+                    batch: int,
+                    cache: str = "",
+                ) -> None:
+                    trt.IInt8Calibrator.__init__(self)
+                    self.dataset = dataset
+                    self.data_iter = iter(dataset)
+                    self.algo = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2
+                    self.batch = batch
+                    self.cache = Path(cache)
+
+                def get_algorithm(self) -> trt.CalibrationAlgoType:
+                    """Get the calibration algorithm to use."""
+                    return self.algo
+
+                def get_batch_size(self) -> int:
+                    """Get the batch size to use for calibration."""
+                    return self.batch or 1
+
+                def get_batch(self, names) -> list:
+                    """Get the next batch to use for calibration, as a list of device memory pointers."""
+                    try:
+                        im0s = next(self.data_iter)["img"] / 255.0
+                        im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
+                        return [int(im0s.data_ptr())]
+                    except StopIteration:
+                        # Return [] or None, signal to TensorRT there is no calibration data remaining
+                        return None
+
+                def read_calibration_cache(self) -> bytes:
+                    """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
+                    if self.cache.exists() and self.cache.suffix == ".cache":
+                        return self.cache.read_bytes()
+
+                def write_calibration_cache(self, cache) -> None:
+                    """Write calibration cache to disk."""
+                    _ = self.cache.write_bytes(cache)
+
+            # Load dataset w/ builder (for batching) and calibrate
+            config.int8_calibrator = EngineCalibrator(
+                dataset=self.get_int8_calibration_dataloader(prefix),
+                batch=2 * self.args.batch,  # TensorRT INT8 calibration should use 2x batch size
+                cache=str(self.file.with_suffix(".cache")),
+            )
+
+        elif half:
             config.set_flag(trt.BuilderFlag.FP16)
 
+        # Free CUDA memory
         del self.model
+        gc.collect()
         torch.cuda.empty_cache()
 
         # Write file
-
+        build = builder.build_serialized_network if is_trt10 else builder.build_engine
+        with build(network, config) as engine, open(f, "wb") as t:
             # Metadata
             meta = json.dumps(self.metadata)
             t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
             t.write(meta.encode())
             # Model
-            t.write(engine.serialize())
+            t.write(engine if is_trt10 else engine.serialize())
 
         return f, None
 
     @try_export
     def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
-        """
+        """YOLO TensorFlow SavedModel export."""
         cuda = torch.cuda.is_available()
         try:
             import tensorflow as tf  # noqa
         except ImportError:
             suffix = "-macos" if MACOS else "-aarch64" if ARM64 else "" if cuda else "-cpu"
-            version = "
+            version = ">=2.0.0"
             check_requirements(f"tensorflow{suffix}{version}")
             import tensorflow as tf  # noqa
-        if ARM64:
-            check_requirements("cmake")  # 'cmake' is needed to build onnxsim on aarch64
         check_requirements(
             (
+                "keras",  # required by 'onnx2tf' package
+                "tf_keras",  # required by 'onnx2tf' package
+                "sng4onnx>=1.0.1",  # required by 'onnx2tf' package
+                "onnx_graphsurgeon>=0.3.26",  # required by 'onnx2tf' package
                 "onnx>=1.12.0",
-                "onnx2tf
-                "
-                "
-                "onnx_graphsurgeon>=0.3.26",
-                "tflite_support",
+                "onnx2tf>1.17.5,<=1.26.3",
+                "onnxslim>=0.1.31",
+                "tflite_support<=0.4.3" if IS_JETSON else "tflite_support",  # fix ImportError 'GLIBCXX_3.4.29'
                 "flatbuffers>=23.5.26,<100",  # update old 'flatbuffers' included inside tensorflow package
                 "onnxruntime-gpu" if cuda else "onnxruntime",
             ),
-            cmds="--extra-index-url https://pypi.ngc.nvidia.com",
-        )
+            cmds="--extra-index-url https://pypi.ngc.nvidia.com",  # onnx_graphsurgeon only on NVIDIA
+        )
 
         LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
         check_version(
             tf.__version__,
-            "
+            ">=2.0.0",
             name="tensorflow",
             verbose=True,
             msg="https://github.com/ultralytics/ultralytics/issues/5161",
@@ -771,39 +989,29 @@ class Exporter:
         f_onnx, _ = self.export_onnx()
 
         # Export to TF
-        tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
         np_data = None
         if self.args.int8:
-
+            tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
             if self.args.data:
-                # Generate calibration data for integer quantization
-                LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
-                data = check_det_dataset(self.args.data)
-                dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
-                images = []
-                for i, batch in enumerate(dataset):
-                    if i >= 100:  # maximum number of calibration images
-                        break
-                    im = batch["img"].permute(1, 2, 0)[None]  # list to nparray, CHW to BHWC
-                    images.append(im)
                 f.mkdir()
-                images =
-
-
-
+                images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
+                images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
+                    0, 2, 3, 1
+                )
+                np.save(str(tmp_file), images.numpy().astype(np.float32))  # BHWC
                 np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
-        else:
-            verbosity = "error"
 
         LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
-        onnx2tf.convert(
+        keras_model = onnx2tf.convert(
             input_onnx_file_path=f_onnx,
             output_folder_path=str(f),
             not_use_onnxsim=True,
-            verbosity=
+            verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
             output_integer_quantized_tflite=self.args.int8,
             quant_type="per-tensor",  # "per-tensor" (faster) or "per-channel" (slower but more accurate)
             custom_input_op_name_np_data_path=np_data,
+            disable_group_convolution=True,  # for end-to-end model compatibility
+            enable_batchmatmul_unfold=True,  # for end-to-end model compatibility
         )
         yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
 
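TFLite INT8 calibration now reuses the shared dataloader: batches are concatenated, resized with `torch.nn.functional.interpolate`, permuted to BHWC and saved to a temporary `.npy` file for onnx2tf. A matching usage sketch:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.export(format="tflite", int8=True, data="coco8.yaml")  # calibration images drawn from 'data'
```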
@@ -819,11 +1027,11 @@ class Exporter:
|
|
819
1027
|
for file in f.rglob("*.tflite"):
|
820
1028
|
f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file)
|
821
1029
|
|
822
|
-
return str(f), tf.saved_model.load(f, tags=None, options=None)
|
1030
|
+
return str(f), keras_model # or keras_model = tf.saved_model.load(f, tags=None, options=None)
|
823
1031
|
|
824
1032
|
@try_export
|
825
1033
|
def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
|
826
|
-
"""
|
1034
|
+
"""YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
|
827
1035
|
import tensorflow as tf # noqa
|
828
1036
|
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
829
1037
|
|
@@ -839,7 +1047,8 @@ class Exporter:
|
|
839
1047
|
|
840
1048
|
@try_export
|
841
1049
|
def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
|
842
|
-
"""
|
1050
|
+
"""YOLO TensorFlow Lite export."""
|
1051
|
+
# BUG https://github.com/ultralytics/ultralytics/issues/13436
|
843
1052
|
import tensorflow as tf # noqa
|
844
1053
|
|
845
1054
|
LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
|
@@ -854,7 +1063,7 @@
 
     @try_export
     def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
-        """
+        """YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
         LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")
 
         cmd = "edgetpu_compiler --version"
@@ -876,7 +1085,15 @@
         LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
         f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model
 
-        cmd =
+        cmd = (
+            "edgetpu_compiler "
+            f'--out_dir "{Path(f).parent}" '
+            "--show_operations "
+            "--search_delegate "
+            "--delegate_search_step 30 "
+            "--timeout_sec 180 "
+            f'"{tflite_model}"'
+        )
         LOGGER.info(f"{prefix} running '{cmd}'")
         subprocess.run(cmd, shell=True)
         self._add_tflite_metadata(f)
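The truncated single-line `cmd =` on the old side becomes an explicit multi-flag invocation; `--search_delegate` with `--delegate_search_step 30` lets the compiler retry with progressively smaller delegated subgraphs when the full model cannot map to the Edge TPU. A sketch of the assembled command for a hypothetical model path:

```python
from pathlib import Path

tflite_model = "yolov8n_full_integer_quant.tflite"  # hypothetical INT8 TFLite input
cmd = (
    "edgetpu_compiler "
    f'--out_dir "{Path(tflite_model).parent}" '
    "--show_operations "
    "--search_delegate "
    "--delegate_search_step 30 "
    "--timeout_sec 180 "
    f'"{tflite_model}"'
)
print(cmd)
# edgetpu_compiler --out_dir "." --show_operations --search_delegate --delegate_search_step 30 --timeout_sec 180 "yolov8n_full_integer_quant.tflite"
```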
@@ -884,7 +1101,7 @@
 
     @try_export
     def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
-        """
+        """YOLO TensorFlow.js export."""
         check_requirements("tensorflowjs")
         if ARM64:
             # Fix error: `np.object` was a deprecated alias for the builtin `object` when exporting to TF.js on ARM64
@@ -914,31 +1131,157 @@ class Exporter:
         if " " in f:
             LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")
 
-        #
-        # with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
-        #     subst = re.sub(
-        #         r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
-        #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
-        #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
-        #         r'"Identity.?.?": {"name": "Identity.?.?"}}}',
-        #         r'{"outputs": {"Identity": {"name": "Identity"}, '
-        #         r'"Identity_1": {"name": "Identity_1"}, '
-        #         r'"Identity_2": {"name": "Identity_2"}, '
-        #         r'"Identity_3": {"name": "Identity_3"}}}',
-        #         f_json.read_text(),
-        #     )
-        #     j.write(subst)
+        # Add metadata
         yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
         return f, None
 
+    @try_export
+    def export_imx(self, prefix=colorstr("IMX:")):
+        """YOLO IMX export."""
+        gptq = False
+        assert LINUX, (
+            "export only supported on Linux. See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
+        )
+        if getattr(self.model, "end2end", False):
+            raise ValueError("IMX export is not supported for end2end models.")
+        if "C2f" not in self.model.__str__():
+            raise ValueError("IMX export is only supported for YOLOv8n detection models")
+        check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers==0.2.0", "tensorflow==2.12.0"))
+        check_requirements("imx500-converter[pt]==3.14.3")  # Separate requirements for imx500-converter
+
+        import model_compression_toolkit as mct
+        import onnx
+        from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
+
+        try:
+            out = subprocess.run(
+                ["java", "--version"], check=True, capture_output=True
+            )  # Java 17 is required for imx500-converter
+            if "openjdk 17" not in str(out.stdout):
+                raise FileNotFoundError
+        except FileNotFoundError:
+            subprocess.run(["sudo", "apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"], check=True)
+
+        def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
+            for batch in dataloader:
+                img = batch["img"]
+                img = img / 255.0
+                yield [img]
+
+        tpc = mct.get_target_platform_capabilities(
+            fw_name="pytorch", target_platform_name="imx500", target_platform_version="v1"
+        )
+
+        config = mct.core.CoreConfig(
+            mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
+            quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
+        )
+
+        resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76)
+
+        quant_model = (
+            mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
+                model=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                gptq_config=mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False),
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )[0]
+            if gptq
+            else mct.ptq.pytorch_post_training_quantization(  # Perform post training quantization
+                in_module=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )[0]
+        )
+
+        class NMSWrapper(torch.nn.Module):
+            def __init__(
+                self,
+                model: torch.nn.Module,
+                score_threshold: float = 0.001,
+                iou_threshold: float = 0.7,
+                max_detections: int = 300,
+            ):
+                """
+                Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
+
+                Args:
+                    model (nn.Module): Model instance.
+                    score_threshold (float): Score threshold for non-maximum suppression.
+                    iou_threshold (float): Intersection over union threshold for non-maximum suppression.
+                    max_detections (float): The number of detections to return.
+                """
+                super().__init__()
+                self.model = model
+                self.score_threshold = score_threshold
+                self.iou_threshold = iou_threshold
+                self.max_detections = max_detections
+
+            def forward(self, images):
+                # model inference
+                outputs = self.model(images)
+
+                boxes = outputs[0]
+                scores = outputs[1]
+                nms = multiclass_nms(
+                    boxes=boxes,
+                    scores=scores,
+                    score_threshold=self.score_threshold,
+                    iou_threshold=self.iou_threshold,
+                    max_detections=self.max_detections,
+                )
+                return nms
+
+        quant_model = NMSWrapper(
+            model=quant_model,
+            score_threshold=self.args.conf or 0.001,
+            iou_threshold=self.args.iou,
+            max_detections=self.args.max_det,
+        ).to(self.device)
+
+        f = Path(str(self.file).replace(self.file.suffix, "_imx_model"))
+        f.mkdir(exist_ok=True)
+        onnx_model = f / Path(str(self.file).replace(self.file.suffix, "_imx.onnx"))  # js dir
+        mct.exporter.pytorch_export_model(
+            model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
+        )
+
+        model_onnx = onnx.load(onnx_model)  # load onnx model
+        for k, v in self.metadata.items():
+            meta = model_onnx.metadata_props.add()
+            meta.key, meta.value = k, str(v)
+
+        onnx.save(model_onnx, onnx_model)
+
+        subprocess.run(
+            ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
+            check=True,
+        )
+
+        # Needed for imx models.
+        with open(f / "labels.txt", "w") as file:
+            file.writelines([f"{name}\n" for _, name in self.model.names.items()])
+
+        return f, None
+
1271
|
def _add_tflite_metadata(self, file):
|
935
1272
|
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
|
936
|
-
|
937
|
-
|
938
|
-
|
1273
|
+
import flatbuffers
|
1274
|
+
|
1275
|
+
try:
|
1276
|
+
# TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845
|
1277
|
+
from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema # noqa
|
1278
|
+
from tensorflow_lite_support.metadata.python import metadata # noqa
|
1279
|
+
except ImportError: # ARM64 systems may not have the 'tensorflow_lite_support' package available
|
1280
|
+
from tflite_support import metadata # noqa
|
1281
|
+
from tflite_support import metadata_schema_py_generated as schema # noqa
|
939
1282
|
|
940
1283
|
# Create model info
|
941
|
-
model_meta =
|
1284
|
+
model_meta = schema.ModelMetadataT()
|
942
1285
|
model_meta.name = self.metadata["description"]
|
943
1286
|
model_meta.version = self.metadata["version"]
|
944
1287
|
model_meta.author = self.metadata["author"]
|
@@ -949,48 +1292,48 @@ class Exporter:
         with open(tmp_file, "w") as f:
             f.write(str(self.metadata))
 
-        label_file =
+        label_file = schema.AssociatedFileT()
         label_file.name = tmp_file.name
-        label_file.type =
+        label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS
 
         # Create input info
-        input_meta =
+        input_meta = schema.TensorMetadataT()
         input_meta.name = "image"
         input_meta.description = "Input image to be detected."
-        input_meta.content =
-        input_meta.content.contentProperties =
-        input_meta.content.contentProperties.colorSpace =
-        input_meta.content.contentPropertiesType =
+        input_meta.content = schema.ContentT()
+        input_meta.content.contentProperties = schema.ImagePropertiesT()
+        input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB
+        input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties
 
         # Create output info
-        output1 =
+        output1 = schema.TensorMetadataT()
         output1.name = "output"
         output1.description = "Coordinates of detected objects, class labels, and confidence score"
         output1.associatedFiles = [label_file]
         if self.model.task == "segment":
-            output2 =
+            output2 = schema.TensorMetadataT()
             output2.name = "output"
             output2.description = "Mask protos"
             output2.associatedFiles = [label_file]
 
         # Create subgraph info
-        subgraph =
+        subgraph = schema.SubGraphMetadataT()
         subgraph.inputTensorMetadata = [input_meta]
         subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
         model_meta.subgraphMetadata = [subgraph]
 
         b = flatbuffers.Builder(0)
-        b.Finish(model_meta.Pack(b),
+        b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
         metadata_buf = b.Output()
 
-        populator =
+        populator = metadata.MetadataPopulator.with_model_file(str(file))
         populator.load_metadata_buffer(metadata_buf)
         populator.load_associated_files([str(tmp_file)])
         populator.populate()
         tmp_file.unlink()
 
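Once populated, the metadata written by the block above can be read back to verify the round trip; a sketch using the `tflite_support` fallback import path (the model filename is assumed):

```python
from tflite_support import metadata

# Hypothetical path to an exported TFLite file with populated metadata
displayer = metadata.MetadataDisplayer.with_model_file("yolov8n_float32.tflite")
print(displayer.get_metadata_json())                # name/version/author plus tensor descriptions
print(displayer.get_packed_associated_file_list())  # the packed temp metadata text file
```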
     def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
-        """
+        """YOLO CoreML pipeline."""
         import coremltools as ct  # noqa
 
         LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
@@ -1014,27 +1357,11 @@
         names = self.metadata["names"]
         nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
         _, nc = out0_shape  # number of anchors, number of classes
-        # _, nc = out0.type.multiArrayType.shape
         assert len(names) == nc, f"{len(names)} names found for nc={nc}"  # check
 
         # Define output shapes (missing)
         out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
         out1.type.multiArrayType.shape[:] = out1_shape  # (3780, 4)
-        # spec.neuralNetwork.preprocessing[0].featureName = '0'
-
-        # Flexible input shapes
-        # from coremltools.models.neural_network import flexible_shape_utils
-        # s = []  # shapes
-        # s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
-        # s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384))  # (height, width)
-        # flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
-        # r = flexible_shape_utils.NeuralNetworkImageSizeRange()  # shape ranges
-        # r.add_height_range((192, 640))
-        # r.add_width_range((192, 640))
-        # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)
-
-        # Print
-        # print(spec.description)
 
         # Model from spec
         model = ct.models.MLModel(spec, weights_dir=weights_dir)