ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of ultralytics has been flagged as potentially problematic.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
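
The remainder of the diff details ultralytics/engine/exporter.py, by far the largest change in this release. Nearly every hunk below is a mechanical reformat: single-quoted strings become double-quoted and long calls are re-wrapped in Black style, with no behavior change. For orientation, the export pipeline these hunks touch is normally driven through the model API; a minimal sketch of the entry point (the weights filename is illustrative):

    # Minimal sketch of invoking the exporter below; assumes ultralytics is installed.
    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")   # "yolov8n.pt" is an illustrative weights file
    model.export(format="onnx")  # Exporter.__call__ dispatches to export_onnx() below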
ultralytics/engine/exporter.py CHANGED

@@ -67,8 +67,20 @@ from ultralytics.data.utils import check_det_dataset
 from ultralytics.nn.autobackend import check_class_names, default_class_names
 from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
 from ultralytics.nn.tasks import DetectionModel, SegmentationModel
-from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks,
-                               colorstr, get_default_args, yaml_save)
+from ultralytics.utils import (
+    ARM64,
+    DEFAULT_CFG,
+    LINUX,
+    LOGGER,
+    MACOS,
+    ROOT,
+    WINDOWS,
+    __version__,
+    callbacks,
+    colorstr,
+    get_default_args,
+    yaml_save,
+)
 from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
 from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
 from ultralytics.utils.files import file_size, spaces_in_path
@@ -79,21 +91,23 @@ from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
 def export_formats():
     """YOLOv8 export formats."""
     import pandas
+
     x = [
-        ['PyTorch', '-', '.pt', True, True],
-        ['TorchScript', 'torchscript', '.torchscript', True, True],
-        ['ONNX', 'onnx', '.onnx', True, True],
-        ['OpenVINO', 'openvino', '_openvino_model', True, False],
-        ['TensorRT', 'engine', '.engine', False, True],
-        ['CoreML', 'coreml', '.mlpackage', True, False],
-        ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
-        ['TensorFlow GraphDef', 'pb', '.pb', True, True],
-        ['TensorFlow Lite', 'tflite', '.tflite', True, False],
-        ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
-        ['TensorFlow.js', 'tfjs', '_web_model', True, False],
-        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
-        ['ncnn', 'ncnn', '_ncnn_model', True, True], ]
-    return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
+        ["PyTorch", "-", ".pt", True, True],
+        ["TorchScript", "torchscript", ".torchscript", True, True],
+        ["ONNX", "onnx", ".onnx", True, True],
+        ["OpenVINO", "openvino", "_openvino_model", True, False],
+        ["TensorRT", "engine", ".engine", False, True],
+        ["CoreML", "coreml", ".mlpackage", True, False],
+        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
+        ["TensorFlow GraphDef", "pb", ".pb", True, True],
+        ["TensorFlow Lite", "tflite", ".tflite", True, False],
+        ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
+        ["TensorFlow.js", "tfjs", "_web_model", True, False],
+        ["PaddlePaddle", "paddle", "_paddle_model", True, True],
+        ["ncnn", "ncnn", "_ncnn_model", True, True],
+    ]
+    return pandas.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])


 def gd_outputs(gd):
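
The reformatted export_formats() table above is what the exporter queries for valid format arguments. A small sketch of inspecting the returned DataFrame (column names taken from the hunk above):

    from ultralytics.engine.exporter import export_formats

    df = export_formats()             # pandas DataFrame: Format, Argument, Suffix, CPU, GPU
    fmts = tuple(df["Argument"][1:])  # same expression Exporter.__call__ uses for valid formats
    print(fmts)                       # ('torchscript', 'onnx', 'openvino', ..., 'ncnn')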
@@ -102,7 +116,7 @@ def gd_outputs(gd):
     for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
         name_list.append(node.name)
         input_list.extend(node.input)
-    return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
+    return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))


 def try_export(inner_func):
@@ -111,14 +125,14 @@ def try_export(inner_func):

     def outer_func(*args, **kwargs):
         """Export a model."""
-        prefix = inner_args['prefix']
+        prefix = inner_args["prefix"]
         try:
             with Profile() as dt:
                 f, model = inner_func(*args, **kwargs)
             LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
             return f, model
         except Exception as e:
-            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
+            LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
             raise e

     return outer_func
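
try_export, reformatted above, is a decorator that times the wrapped export function, logs success or failure, and re-raises any exception. The same pattern in isolation, as a self-contained sketch (the real implementation additionally pulls the prefix out of the wrapped function's default arguments via get_default_args and times with Profile):

    import time


    def logged_export(prefix):
        """Decorator factory mirroring the try_export pattern above."""
        def decorator(func):
            def wrapper(*args, **kwargs):
                t0 = time.time()
                try:
                    result = func(*args, **kwargs)  # run the wrapped exporter
                    print(f"{prefix} export success ({time.time() - t0:.1f}s)")
                    return result
                except Exception as e:
                    print(f"{prefix} export failure ({time.time() - t0:.1f}s): {e}")
                    raise
            return wrapper
        return decorator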
@@ -143,8 +157,8 @@ class Exporter:
             _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
-        if self.args.format.lower() in ('coreml', 'mlmodel'):  # fix attempt for protobuf<3.20.x errors
-            os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'  # must run before TensorBoard callback
+        if self.args.format.lower() in ("coreml", "mlmodel"):  # fix attempt for protobuf<3.20.x errors
+            os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"  # must run before TensorBoard callback

         self.callbacks = _callbacks or callbacks.get_default_callbacks()
         callbacks.add_integration_callbacks(self)
@@ -152,45 +166,46 @@ class Exporter:
     @smart_inference_mode()
     def __call__(self, model=None):
         """Returns list of exported files/dirs after running callbacks."""
-        self.run_callbacks('on_export_start')
+        self.run_callbacks("on_export_start")
         t = time.time()
         fmt = self.args.format.lower()  # to lowercase
-        if fmt in ('tensorrt', 'trt'):  # 'engine' aliases
-            fmt = 'engine'
-        if fmt in ('mlmodel', 'mlpackage', 'mlprogram', 'apple', 'ios', 'coreml'):  # 'coreml' aliases
-            fmt = 'coreml'
-        fmts = tuple(export_formats()['Argument'][1:])  # available export formats
+        if fmt in ("tensorrt", "trt"):  # 'engine' aliases
+            fmt = "engine"
+        if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"):  # 'coreml' aliases
+            fmt = "coreml"
+        fmts = tuple(export_formats()["Argument"][1:])  # available export formats
         flags = [x == fmt for x in fmts]
         if sum(flags) != 1:
             raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
         jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags  # export booleans

         # Device
-        if fmt == 'engine' and self.args.device is None:
-            LOGGER.warning('WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0')
-            self.args.device = '0'
-        self.device = select_device('cpu' if self.args.device is None else self.args.device)
+        if fmt == "engine" and self.args.device is None:
+            LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
+            self.args.device = "0"
+        self.device = select_device("cpu" if self.args.device is None else self.args.device)

         # Checks
-        if not hasattr(model, 'names'):
+        if not hasattr(model, "names"):
             model.names = default_class_names()
         model.names = check_class_names(model.names)
-        if self.args.half and onnx and self.device.type == 'cpu':
-            LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
+        if self.args.half and onnx and self.device.type == "cpu":
+            LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0")
             self.args.half = False
-            assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
+            assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one."
         self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
         if self.args.optimize:
             assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
-            assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
+            assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
         if edgetpu and not LINUX:
-            raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')
+            raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")

         # Input
         im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
         file = Path(
-            getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
-        if file.suffix in ('.yaml', '.yml'):
+            getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
+        )
+        if file.suffix in {".yaml", ".yml"}:
             file = Path(file.name)

         # Update model
@@ -212,42 +227,48 @@ class Exporter:
         y = None
         for _ in range(2):
             y = model(im)  # dry runs
-        if self.args.half and (engine or onnx) and self.device.type != 'cpu':
+        if self.args.half and (engine or onnx) and self.device.type != "cpu":
             im, model = im.half(), model.half()  # to FP16

         # Filter warnings
-        warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
-        warnings.filterwarnings('ignore', category=UserWarning)  # suppress shape prim::Constant missing ONNX warning
-        warnings.filterwarnings('ignore', category=DeprecationWarning)  # suppress CoreML np.bool deprecation warning
+        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)  # suppress TracerWarning
+        warnings.filterwarnings("ignore", category=UserWarning)  # suppress shape prim::Constant missing ONNX warning
+        warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress CoreML np.bool deprecation warning

         # Assign
         self.im = im
         self.model = model
         self.file = file
-        self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \
-            tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
-        self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
-        data = model.args['data'] if hasattr(model, 'args') and isinstance(model.args, dict) else ''
+        self.output_shape = (
+            tuple(y.shape)
+            if isinstance(y, torch.Tensor)
+            else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
+        )
+        self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
+        data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
         description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}'
         self.metadata = {
-            'description': description,
-            'author': 'Ultralytics',
-            'license': 'AGPL-3.0 https://ultralytics.com/license',
-            'date': datetime.now().isoformat(),
-            'version': __version__,
-            'stride': int(max(model.stride)),
-            'task': model.task,
-            'batch': self.args.batch,
-            'imgsz': self.imgsz,
-            'names': model.names}  # model metadata
-        if model.task == 'pose':
-            self.metadata['kpt_shape'] = model.model[-1].kpt_shape
-
-        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
-                    f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
+            "description": description,
+            "author": "Ultralytics",
+            "license": "AGPL-3.0 https://ultralytics.com/license",
+            "date": datetime.now().isoformat(),
+            "version": __version__,
+            "stride": int(max(model.stride)),
+            "task": model.task,
+            "batch": self.args.batch,
+            "imgsz": self.imgsz,
+            "names": model.names,
+        }  # model metadata
+        if model.task == "pose":
+            self.metadata["kpt_shape"] = model.model[-1].kpt_shape
+
+        LOGGER.info(
+            f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
+            f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)'
+        )

         # Exports
-        f = [''] * len(fmts)  # exported filenames
+        f = [""] * len(fmts)  # exported filenames
         if jit or ncnn:  # TorchScript
             f[0], _ = self.export_torchscript()
         if engine:  # TensorRT required before ONNX
@@ -266,7 +287,7 @@ class Exporter:
         if tflite:
             f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
         if edgetpu:
-            f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
+            f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
         if tfjs:
             f[9], _ = self.export_tfjs()
         if paddle:  # PaddlePaddle
@@ -279,58 +300,65 @@ class Exporter:
         if any(f):
             f = str(Path(f[-1]))
             square = self.imgsz[0] == self.imgsz[1]
-            s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
-                f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
-            imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
-            predict_data = f'data={data}' if model.task == 'segment' and fmt == 'pb' else ''
-            q = 'int8' if self.args.int8 else 'half' if self.args.half else ''  # quantization
-            LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
-                        f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
-                        f'\nPredict:         yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
-                        f'\nValidate:        yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
-                        f'\nVisualize:       https://netron.app')
-
-            self.run_callbacks('on_export_end')
+            s = (
+                ""
+                if square
+                else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
+                f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
+            )
+            imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
+            predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
+            q = "int8" if self.args.int8 else "half" if self.args.half else ""  # quantization
+            LOGGER.info(
+                f'\nExport complete ({time.time() - t:.1f}s)'
+                f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+                f'\nPredict:         yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
+                f'\nValidate:        yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
+                f'\nVisualize:       https://netron.app'
+            )
+
+            self.run_callbacks("on_export_end")
         return f  # return list of exported files/dirs

     @try_export
-    def export_torchscript(self, prefix=colorstr('TorchScript:')):
+    def export_torchscript(self, prefix=colorstr("TorchScript:")):
         """YOLOv8 TorchScript model export."""
-        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
-        f = self.file.with_suffix('.torchscript')
+        LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
+        f = self.file.with_suffix(".torchscript")

         ts = torch.jit.trace(self.model, self.im, strict=False)
-        extra_files = {'config.txt': json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
+        extra_files = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
         if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
-            LOGGER.info(f'{prefix} optimizing for mobile...')
+            LOGGER.info(f"{prefix} optimizing for mobile...")
             from torch.utils.mobile_optimizer import optimize_for_mobile
+
             optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
         else:
             ts.save(str(f), _extra_files=extra_files)
         return f, None

     @try_export
-    def export_onnx(self, prefix=colorstr('ONNX:')):
+    def export_onnx(self, prefix=colorstr("ONNX:")):
         """YOLOv8 ONNX export."""
-        requirements = ['onnx>=1.12.0']
+        requirements = ["onnx>=1.12.0"]
         if self.args.simplify:
-            requirements += ['onnxsim>=0.4.33', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
+            requirements += ["onnxsim>=0.4.33", "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"]
         check_requirements(requirements)
         import onnx  # noqa

         opset_version = self.args.opset or get_latest_opset()
-        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
-        f = str(self.file.with_suffix('.onnx'))
+        LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
+        f = str(self.file.with_suffix(".onnx"))

-        output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
+        output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
         dynamic = self.args.dynamic
         if dynamic:
-            dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
+            dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
             if isinstance(self.model, SegmentationModel):
-                dynamic['output0'] = {0: 'batch', 2: 'anchors'}  # shape(1, 116, 8400)
-                dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
+                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
+                dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
             elif isinstance(self.model, DetectionModel):
-                dynamic['output0'] = {0: 'batch', 2: 'anchors'}  # shape(1, 84, 8400)
+                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)

         torch.onnx.export(
             self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
@@ -339,9 +367,10 @@ class Exporter:
             verbose=False,
             opset_version=opset_version,
             do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
-            input_names=['images'],
+            input_names=["images"],
             output_names=output_names,
-            dynamic_axes=dynamic or None)
+            dynamic_axes=dynamic or None,
+        )

         # Checks
         model_onnx = onnx.load(f)  # load onnx model
@@ -352,12 +381,12 @@ class Exporter:
             try:
                 import onnxsim

-                LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
+                LOGGER.info(f"{prefix} simplifying with onnxsim {onnxsim.__version__}...")
                 # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
                 model_onnx, check = onnxsim.simplify(model_onnx)
-                assert check, 'Simplified ONNX model could not be validated'
+                assert check, "Simplified ONNX model could not be validated"
             except Exception as e:
-                LOGGER.info(f'{prefix} simplifier failure: {e}')
+                LOGGER.info(f"{prefix} simplifier failure: {e}")

         # Metadata
         for k, v in self.metadata.items():
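
The metadata loop at the end of this hunk writes each key/value pair into the exported model's ONNX metadata_props (the assignment lines fall just outside the hunk). Reading them back from an exported file is straightforward with the standard onnx API; a sketch (the filename is illustrative):

    import onnx

    model_onnx = onnx.load("yolov8n.onnx")                      # file produced by export_onnx()
    meta = {p.key: p.value for p in model_onnx.metadata_props}  # description, stride, task, names, ...
    print(meta.get("names"))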
@@ -368,58 +397,56 @@ class Exporter:
         return f, model_onnx

     @try_export
-    def export_openvino(self, prefix=colorstr('OpenVINO:')):
+    def export_openvino(self, prefix=colorstr("OpenVINO:")):
         """YOLOv8 OpenVINO export."""
-        check_requirements('openvino-dev>=2023.0')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+        check_requirements("openvino-dev>=2023.0")  # requires openvino-dev: https://pypi.org/project/openvino-dev/
         import openvino.runtime as ov  # noqa
         from openvino.tools import mo  # noqa

-        LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
-        f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
-        fq = str(self.file).replace(self.file.suffix, f'_int8_openvino_model{os.sep}')
-        f_onnx = self.file.with_suffix('.onnx')
-        f_ov = str(Path(f) / self.file.with_suffix('.xml').name)
-        fq_ov = str(Path(fq) / self.file.with_suffix('.xml').name)
+        LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
+        f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
+        fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
+        f_onnx = self.file.with_suffix(".onnx")
+        f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
+        fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)

         def serialize(ov_model, file):
             """Set RT info, serialize and save metadata YAML."""
-            ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
-            ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
-            ov_model.set_rt_info(114, ['model_info', 'pad_value'])
-            ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
-            ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
-            ov_model.set_rt_info([v.replace(' ', '_') for v in self.model.names.values()], ['model_info', 'labels'])
-            if self.model.task != 'classify':
-                ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])
+            ov_model.set_rt_info("YOLOv8", ["model_info", "model_type"])
+            ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
+            ov_model.set_rt_info(114, ["model_info", "pad_value"])
+            ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
+            ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
+            ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
+            if self.model.task != "classify":
+                ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])

             ov.serialize(ov_model, file)  # save
-            yaml_save(Path(file).parent / 'metadata.yaml', self.metadata)  # add metadata.yaml
+            yaml_save(Path(file).parent / "metadata.yaml", self.metadata)  # add metadata.yaml

-        ov_model = mo.convert_model(f_onnx,
-                                    model_name=self.pretty_name,
-                                    framework='onnx',
-                                    compress_to_fp16=self.args.half)  # export
+        ov_model = mo.convert_model(
+            f_onnx, model_name=self.pretty_name, framework="onnx", compress_to_fp16=self.args.half
+        )  # export

         if self.args.int8:
             assert self.args.data, "INT8 export requires a data argument for calibration, i.e. 'data=coco8.yaml'"
-            check_requirements('nncf>=2.5.0')
+            check_requirements("nncf>=2.5.0")
             import nncf

             def transform_fn(data_item):
                 """Quantization transform function."""
-                im = data_item['img'].numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
+                im = data_item["img"].numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
                 return np.expand_dims(im, 0) if im.ndim == 3 else im

             # Generate calibration data for integer quantization
             LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
             data = check_det_dataset(self.args.data)
-            dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
+            dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
             quantization_dataset = nncf.Dataset(dataset, transform_fn)
-            ignored_scope = nncf.IgnoredScope(types=['Multiply', 'Subtract', 'Sigmoid'])  # ignore operation
-            quantized_ov_model = nncf.quantize(ov_model,
-                                               quantization_dataset,
-                                               preset=nncf.QuantizationPreset.MIXED,
-                                               ignored_scope=ignored_scope)
+            ignored_scope = nncf.IgnoredScope(types=["Multiply", "Subtract", "Sigmoid"])  # ignore operation
+            quantized_ov_model = nncf.quantize(
+                ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope
+            )
             serialize(quantized_ov_model, fq_ov)
             return fq, None

@@ -427,48 +454,49 @@ class Exporter:
         return f, None

     @try_export
-    def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
+    def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
         """YOLOv8 Paddle export."""
-        check_requirements(('paddlepaddle', 'x2paddle'))
+        check_requirements(("paddlepaddle", "x2paddle"))
         import x2paddle  # noqa
         from x2paddle.convert import pytorch2paddle  # noqa

-        LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
-        f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
+        LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
+        f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")

-        pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im])  # export
-        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
+        pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im])  # export
+        yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
         return f, None

     @try_export
-    def export_ncnn(self, prefix=colorstr('ncnn:')):
+    def export_ncnn(self, prefix=colorstr("ncnn:")):
         """
         YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx.
         """
-        check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn')  # requires ncnn
+        check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires ncnn
         import ncnn  # noqa

-        LOGGER.info(f'\n{prefix} starting export with ncnn {ncnn.__version__}...')
-        f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
-        f_ts = self.file.with_suffix('.torchscript')
+        LOGGER.info(f"\n{prefix} starting export with ncnn {ncnn.__version__}...")
+        f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
+        f_ts = self.file.with_suffix(".torchscript")

-        name = Path('pnnx.exe' if WINDOWS else 'pnnx')  # PNNX filename
+        name = Path("pnnx.exe" if WINDOWS else "pnnx")  # PNNX filename
         pnnx = name if name.is_file() else ROOT / name
         if not pnnx.is_file():
             LOGGER.warning(
-                f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
-                'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
-                f'or in {ROOT}. See PNNX repo for full installation instructions.')
-            system = ['macos'] if MACOS else ['windows'] if WINDOWS else ['ubuntu', 'linux']  # operating system
+                f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from "
+                "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
+                f"or in {ROOT}. See PNNX repo for full installation instructions."
+            )
+            system = ["macos"] if MACOS else ["windows"] if WINDOWS else ["ubuntu", "linux"]  # operating system
             try:
-                _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
+                _, assets = get_github_assets(repo="pnnx/pnnx", retry=True)
                 url = [x for x in assets if any(s in x for s in system)][0]
             except Exception as e:
-                url = f'https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip'
-                LOGGER.warning(f'{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}')
-            asset = attempt_download_asset(url, repo='pnnx/pnnx', release='latest')
+                url = f"https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip"
+                LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}")
+            asset = attempt_download_asset(url, repo="pnnx/pnnx", release="latest")
             if check_is_path_safe(Path.cwd(), asset):  # avoid path traversal security vulnerability
-                unzip_dir = Path(asset).with_suffix('')
+                unzip_dir = Path(asset).with_suffix("")
                 (unzip_dir / name).rename(pnnx)  # move binary to ROOT
                 shutil.rmtree(unzip_dir)  # delete unzip dir
                 Path(asset).unlink()  # delete zip
@@ -477,53 +505,56 @@ class Exporter:
         ncnn_args = [
             f'ncnnparam={f / "model.ncnn.param"}',
             f'ncnnbin={f / "model.ncnn.bin"}',
-            f'ncnnpy={f / "model_ncnn.py"}', ]
+            f'ncnnpy={f / "model_ncnn.py"}',
+        ]

         pnnx_args = [
             f'pnnxparam={f / "model.pnnx.param"}',
             f'pnnxbin={f / "model.pnnx.bin"}',
             f'pnnxpy={f / "model_pnnx.py"}',
-            f'pnnxonnx={f / "model.pnnx.onnx"}', ]
+            f'pnnxonnx={f / "model.pnnx.onnx"}',
+        ]

         cmd = [
             str(pnnx),
             str(f_ts),
             *ncnn_args,
             *pnnx_args,
-            f'fp16={int(self.args.half)}',
-            f'device={self.device.type}',
-            f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
+            f"fp16={int(self.args.half)}",
+            f"device={self.device.type}",
+            f'inputshape="{[self.args.batch, 3, *self.imgsz]}"',
+        ]
         f.mkdir(exist_ok=True)  # make ncnn_model directory
         LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
         subprocess.run(cmd, check=True)

         # Remove debug files
-        pnnx_files = [x.split('=')[-1] for x in pnnx_args]
-        for f_debug in ('debug.bin', 'debug.param', 'debug2.bin', 'debug2.param', *pnnx_files):
+        pnnx_files = [x.split("=")[-1] for x in pnnx_args]
+        for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
             Path(f_debug).unlink(missing_ok=True)

-        yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml
+        yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
         return str(f), None

     @try_export
-    def export_coreml(self, prefix=colorstr('CoreML:')):
+    def export_coreml(self, prefix=colorstr("CoreML:")):
         """YOLOv8 CoreML export."""
-        mlmodel = self.args.format.lower() == 'mlmodel'  # legacy *.mlmodel export format requested
-        check_requirements('coremltools>=6.0,<=6.2' if mlmodel else 'coremltools>=7.0')
+        mlmodel = self.args.format.lower() == "mlmodel"  # legacy *.mlmodel export format requested
+        check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=7.0")
         import coremltools as ct  # noqa

-        LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
-        f = self.file.with_suffix('.mlmodel' if mlmodel else '.mlpackage')
+        LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
+        f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
         if f.is_dir():
             shutil.rmtree(f)

         bias = [0.0, 0.0, 0.0]
         scale = 1 / 255
         classifier_config = None
-        if self.model.task == 'classify':
+        if self.model.task == "classify":
             classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
             model = self.model
-        elif self.model.task == 'detect':
+        elif self.model.task == "detect":
             model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
         else:
             if self.args.nms:
@@ -532,69 +563,73 @@ class Exporter:
             model = self.model

         ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
-        ct_model = ct.convert(ts,
-                              inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
-                              classifier_config=classifier_config,
-                              convert_to='neuralnetwork' if mlmodel else 'mlprogram')
-        bits, mode = (8, 'kmeans') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
+        ct_model = ct.convert(
+            ts,
+            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
+            classifier_config=classifier_config,
+            convert_to="neuralnetwork" if mlmodel else "mlprogram",
+        )
+        bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
         if bits < 32:
-            if 'kmeans' in mode:
-                check_requirements('scikit-learn')  # scikit-learn package required for k-means quantization
+            if "kmeans" in mode:
+                check_requirements("scikit-learn")  # scikit-learn package required for k-means quantization
             if mlmodel:
                 ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
             elif bits == 8:  # mlprogram already quantized to FP16
                 import coremltools.optimize.coreml as cto
-                op_config = cto.OpPalettizerConfig(mode='kmeans', nbits=bits, weight_threshold=512)
+
+                op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
                 config = cto.OptimizationConfig(global_config=op_config)
                 ct_model = cto.palettize_weights(ct_model, config=config)
-        if self.args.nms and self.model.task == 'detect':
+        if self.args.nms and self.model.task == "detect":
             if mlmodel:
                 import platform

                 # coremltools<=6.2 NMS export requires Python<3.11
-                check_version(platform.python_version(), '<3.11', name='Python ', hard=True)
+                check_version(platform.python_version(), "<3.11", name="Python ", hard=True)
                 weights_dir = None
             else:
                 ct_model.save(str(f))  # save otherwise weights_dir does not exist
-                weights_dir = str(f / 'Data/com.apple.CoreML/weights')
+                weights_dir = str(f / "Data/com.apple.CoreML/weights")
             ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)

         m = self.metadata  # metadata dict
-        ct_model.short_description = m.pop('description')
-        ct_model.author = m.pop('author')
-        ct_model.license = m.pop('license')
-        ct_model.version = m.pop('version')
+        ct_model.short_description = m.pop("description")
+        ct_model.author = m.pop("author")
+        ct_model.license = m.pop("license")
+        ct_model.version = m.pop("version")
         ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
         try:
             ct_model.save(str(f))  # save *.mlpackage
         except Exception as e:
             LOGGER.warning(
-                f'{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. '
-                f'Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928.')
-            f = f.with_suffix('.mlmodel')
+                f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
+                f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
+            )
+            f = f.with_suffix(".mlmodel")
             ct_model.save(str(f))
         return f, ct_model

     @try_export
-    def export_engine(self, prefix=colorstr('TensorRT:')):
+    def export_engine(self, prefix=colorstr("TensorRT:")):
         """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
-        assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
+        assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
         f_onnx, _ = self.export_onnx()  # run before trt import https://github.com/ultralytics/ultralytics/issues/7016

         try:
             import tensorrt as trt  # noqa
         except ImportError:
             if LINUX:
-                check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
+                check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
             import tensorrt as trt  # noqa

-        check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
+        check_version(trt.__version__, "7.0.0", hard=True)  # require tensorrt>=7.0.0

         self.args.simplify = True

-        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
-        assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
-        f = self.file.with_suffix('.engine')  # TensorRT engine file
+        LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
+        assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
+        f = self.file.with_suffix(".engine")  # TensorRT engine file
         logger = trt.Logger(trt.Logger.INFO)
         if self.args.verbose:
             logger.min_severity = trt.Logger.Severity.VERBOSE
@@ -604,11 +639,11 @@ class Exporter:
         config.max_workspace_size = self.args.workspace * 1 << 30
         # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

-        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+        flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         network = builder.create_network(flag)
         parser = trt.OnnxParser(network, logger)
         if not parser.parse_from_file(f_onnx):
-            raise RuntimeError(f'failed to load ONNX file: {f_onnx}')
+            raise RuntimeError(f"failed to load ONNX file: {f_onnx}")

         inputs = [network.get_input(i) for i in range(network.num_inputs)]
         outputs = [network.get_output(i) for i in range(network.num_outputs)]
@@ -627,7 +662,8 @@ class Exporter:
         config.add_optimization_profile(profile)

         LOGGER.info(
-            f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
+            f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
+        )
         if builder.platform_has_fast_fp16 and self.args.half:
             config.set_flag(trt.BuilderFlag.FP16)

@@ -635,10 +671,10 @@ class Exporter:
             torch.cuda.empty_cache()

         # Write file
-        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
+        with builder.build_engine(network, config) as engine, open(f, "wb") as t:
             # Metadata
             meta = json.dumps(self.metadata)
-            t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
+            t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
             t.write(meta.encode())
             # Model
             t.write(engine.serialize())
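
The write block above gives the .engine file a small self-describing header: a 4-byte little-endian length, the JSON-encoded metadata, then the serialized engine. The matching reader is the inverse; a sketch of recovering the metadata (the filename is illustrative):

    import json

    with open("yolov8n.engine", "rb") as t:
        meta_len = int.from_bytes(t.read(4), byteorder="little", signed=True)  # header: metadata length
        metadata = json.loads(t.read(meta_len).decode())                       # dict written by export_engine()
        engine_bytes = t.read()                                                # remaining bytes: TensorRT engine
    print(metadata["task"], metadata["imgsz"])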
@@ -646,7 +682,7 @@ class Exporter:
         return f, None

     @try_export
-    def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
+    def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
         """YOLOv8 TensorFlow SavedModel export."""
         cuda = torch.cuda.is_available()
         try:
@@ -655,44 +691,55 @@ class Exporter:
             check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
             import tensorflow as tf  # noqa
         check_requirements(
-            ('onnx', 'onnx2tf>=1.15.4,<=1.17.5', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.33', 'onnx_graphsurgeon>=0.3.26',
-             'tflite_support', 'onnxruntime-gpu' if cuda else 'onnxruntime'),
-            cmds='--extra-index-url https://pypi.ngc.nvidia.com')  # onnx_graphsurgeon only on NVIDIA
-
-        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
-        check_version(tf.__version__,
-                      '<=2.13.1',
-                      name='tensorflow',
-                      verbose=True,
-                      msg='https://github.com/ultralytics/ultralytics/issues/5161')
-        f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
+            (
+                "onnx",
+                "onnx2tf>=1.15.4,<=1.17.5",
+                "sng4onnx>=1.0.1",
+                "onnxsim>=0.4.33",
+                "onnx_graphsurgeon>=0.3.26",
+                "tflite_support",
+                "onnxruntime-gpu" if cuda else "onnxruntime",
+            ),
+            cmds="--extra-index-url https://pypi.ngc.nvidia.com",
+        )  # onnx_graphsurgeon only on NVIDIA
+
+        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+        check_version(
+            tf.__version__,
+            "<=2.13.1",
+            name="tensorflow",
+            verbose=True,
+            msg="https://github.com/ultralytics/ultralytics/issues/5161",
+        )
+        f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
         if f.is_dir():
             import shutil
+
             shutil.rmtree(f)  # delete output folder

         # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
-        onnx2tf_file = Path('calibration_image_sample_data_20x128x128x3_float32.npy')
+        onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
         if not onnx2tf_file.exists():
-            attempt_download_asset(f'{onnx2tf_file}.zip', unzip=True, delete=True)
+            attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)

         # Export to ONNX
         self.args.simplify = True
         f_onnx, _ = self.export_onnx()

         # Export to TF
-        tmp_file = f / 'tmp_tflite_int8_calibration_images.npy'  # int8 calibration images file
+        tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
         if self.args.int8:
-            verbosity = '--verbosity info'
+            verbosity = "--verbosity info"
             if self.args.data:
                 # Generate calibration data for integer quantization
                 LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
                 data = check_det_dataset(self.args.data)
-                dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
+                dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
                 images = []
                 for i, batch in enumerate(dataset):
                     if i >= 100:  # maximum number of calibration images
                         break
-                    im = batch['img'].permute(1, 2, 0)[None]  # list to nparray, CHW to BHWC
+                    im = batch["img"].permute(1, 2, 0)[None]  # list to nparray, CHW to BHWC
                     images.append(im)
                 f.mkdir()
                 images = torch.cat(images, 0).float()
@@ -701,38 +748,38 @@ class Exporter:
                 np.save(str(tmp_file), images.numpy())  # BHWC
                 int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
             else:
-                int8 = '-oiqt -qt per-tensor'
+                int8 = "-oiqt -qt per-tensor"
         else:
-            verbosity = '--non_verbose'
-            int8 = ''
+            verbosity = "--non_verbose"
+            int8 = ""

         cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo {verbosity} {int8}'.strip()
         LOGGER.info(f"{prefix} running '{cmd}'")
         subprocess.run(cmd, shell=True)
-        yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml
+        yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml

         # Remove/rename TFLite models
         if self.args.int8:
             tmp_file.unlink(missing_ok=True)
-            for file in f.rglob('*_dynamic_range_quant.tflite'):
-                file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
-            for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
+            for file in f.rglob("*_dynamic_range_quant.tflite"):
+                file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
+            for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
                 file.unlink()  # delete extra fp16 activation TFLite files

         # Add TFLite metadata
-        for file in f.rglob('*.tflite'):
-            f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file)
+        for file in f.rglob("*.tflite"):
+            f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file)

         return str(f), tf.saved_model.load(f, tags=None, options=None)  # load saved_model as Keras model

     @try_export
-    def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
+    def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
         """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
         import tensorflow as tf  # noqa
         from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

-        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
-        f = self.file.with_suffix('.pb')
+        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+        f = self.file.with_suffix(".pb")

         m = tf.function(lambda x: keras_model(x))  # full model
         m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
@@ -742,40 +789,43 @@ class Exporter:
         return f, None

     @try_export
-    def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
+    def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
         """YOLOv8 TensorFlow Lite export."""
         import tensorflow as tf  # noqa

-        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
-        saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
+        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+        saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
         if self.args.int8:
-            f = saved_model / f'{self.file.stem}_int8.tflite'  # fp32 in/out
+            f = saved_model / f"{self.file.stem}_int8.tflite"  # fp32 in/out
         elif self.args.half:
-            f = saved_model / f'{self.file.stem}_float16.tflite'  # fp32 in/out
+            f = saved_model / f"{self.file.stem}_float16.tflite"  # fp32 in/out
         else:
-            f = saved_model / f'{self.file.stem}_float32.tflite'
+            f = saved_model / f"{self.file.stem}_float32.tflite"
         return str(f), None

     @try_export
-    def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
+    def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
         """YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
-        LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
+        LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")

-        cmd = 'edgetpu_compiler --version'
-        help_url = 'https://coral.ai/docs/edgetpu/compiler/'
-        assert LINUX, f'export only supported on Linux. See {help_url}'
+        cmd = "edgetpu_compiler --version"
+        help_url = "https://coral.ai/docs/edgetpu/compiler/"
+        assert LINUX, f"export only supported on Linux. See {help_url}"
         if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
-            LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
-            sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
-            for c in (
-                    'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
-                    'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
-                    'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
-                subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
+            LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
+            sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0  # sudo installed on system
+            for c in (
+                "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
+                'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
+                "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
+                "sudo apt-get update",
+                "sudo apt-get install edgetpu-compiler",
+            ):
+                subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True)
         ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

-        LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
-        f = str(tflite_model).replace('.tflite', '_edgetpu.tflite')  # Edge TPU model
+        LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
+        f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model

         cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
         LOGGER.info(f"{prefix} running '{cmd}'")
@@ -784,30 +834,30 @@ class Exporter:
         return f, None

     @try_export
-    def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
+    def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
         """YOLOv8 TensorFlow.js export."""
         # JAX bug requiring install constraints in https://github.com/google/jax/issues/18978
-        check_requirements(['jax<=0.4.21', 'jaxlib<=0.4.21', 'tensorflowjs'])
+        check_requirements(["jax<=0.4.21", "jaxlib<=0.4.21", "tensorflowjs"])
         import tensorflow as tf
         import tensorflowjs as tfjs  # noqa

-        LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
-        f = str(self.file).replace(self.file.suffix, '_web_model')  # js dir
-        f_pb = str(self.file.with_suffix('.pb'))  # *.pb path
+        LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
+        f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
+        f_pb = str(self.file.with_suffix(".pb"))  # *.pb path

         gd = tf.Graph().as_graph_def()  # TF GraphDef
-        with open(f_pb, 'rb') as file:
+        with open(f_pb, "rb") as file:
             gd.ParseFromString(file.read())
-        outputs = ','.join(gd_outputs(gd))
-        LOGGER.info(f'\n{prefix} output node names: {outputs}')
+        outputs = ",".join(gd_outputs(gd))
+        LOGGER.info(f"\n{prefix} output node names: {outputs}")

-        quantization = '--quantize_float16' if self.args.half else '--quantize_uint8' if self.args.int8 else ''
+        quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
         with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
             cmd = f'tensorflowjs_converter --input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
             LOGGER.info(f"{prefix} running '{cmd}'")
             subprocess.run(cmd, shell=True)

-        if ' ' in f:
+        if " " in f:
             LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")

         # f_json = Path(f) / 'model.json'  # *.json path
@@ -824,7 +874,7 @@ class Exporter:
         #     f_json.read_text(),
         # )
         # j.write(subst)
-        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
+        yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
         return f, None

     def _add_tflite_metadata(self, file):
@@ -835,14 +885,14 @@ class Exporter:

         # Create model info
         model_meta = _metadata_fb.ModelMetadataT()
-        model_meta.name = self.metadata['description']
-        model_meta.version = self.metadata['version']
-        model_meta.author = self.metadata['author']
-        model_meta.license = self.metadata['license']
+        model_meta.name = self.metadata["description"]
+        model_meta.version = self.metadata["version"]
+        model_meta.author = self.metadata["author"]
+        model_meta.license = self.metadata["license"]

         # Label file
-        tmp_file = Path(file).parent / 'temp_meta.txt'
-        with open(tmp_file, 'w') as f:
+        tmp_file = Path(file).parent / "temp_meta.txt"
+        with open(tmp_file, "w") as f:
             f.write(str(self.metadata))

         label_file = _metadata_fb.AssociatedFileT()
@@ -851,8 +901,8 @@ class Exporter:

         # Create input info
         input_meta = _metadata_fb.TensorMetadataT()
-        input_meta.name = 'image'
-        input_meta.description = 'Input image to be detected.'
+        input_meta.name = "image"
+        input_meta.description = "Input image to be detected."
         input_meta.content = _metadata_fb.ContentT()
         input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
         input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
@@ -860,19 +910,19 @@ class Exporter:

         # Create output info
         output1 = _metadata_fb.TensorMetadataT()
-        output1.name = 'output'
-        output1.description = 'Coordinates of detected objects, class labels, and confidence score'
+        output1.name = "output"
+        output1.description = "Coordinates of detected objects, class labels, and confidence score"
         output1.associatedFiles = [label_file]
-        if self.model.task == 'segment':
+        if self.model.task == "segment":
             output2 = _metadata_fb.TensorMetadataT()
-            output2.name = 'output'
-            output2.description = 'Mask protos'
+            output2.name = "output"
+            output2.description = "Mask protos"
             output2.associatedFiles = [label_file]

         # Create subgraph info
         subgraph = _metadata_fb.SubGraphMetadataT()
         subgraph.inputTensorMetadata = [input_meta]
-        subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
+        subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
         model_meta.subgraphMetadata = [subgraph]

         b = flatbuffers.Builder(0)
@@ -885,11 +935,11 @@ class Exporter:
         populator.populate()
         tmp_file.unlink()

-    def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr('CoreML Pipeline:')):
+    def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
         """YOLOv8 CoreML pipeline."""
         import coremltools as ct  # noqa

-        LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
+        LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
         _, _, h, w = list(self.im.shape)  # BCHW

         # Output shapes
@@ -897,8 +947,9 @@ class Exporter:
         out0, out1 = iter(spec.description.output)
         if MACOS:
             from PIL import Image
-            img = Image.new('RGB', (w, h))  # w=192, h=320
-            out = model.predict({'image': img})
+
+            img = Image.new("RGB", (w, h))  # w=192, h=320
+            out = model.predict({"image": img})
             out0_shape = out[out0.name].shape  # (3780, 80)
             out1_shape = out[out1.name].shape  # (3780, 4)
         else:  # linux and windows can not run model.predict(), get sizes from PyTorch model output y
@@ -906,11 +957,11 @@ class Exporter:
             out1_shape = self.output_shape[2], 4  # (3780, 4)

         # Checks
-        names = self.metadata['names']
+        names = self.metadata["names"]
         nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
         _, nc = out0_shape  # number of anchors, number of classes
         # _, nc = out0.type.multiArrayType.shape
-        assert len(names) == nc, f'{len(names)} names found for nc={nc}'  # check
+        assert len(names) == nc, f"{len(names)} names found for nc={nc}"  # check

         # Define output shapes (missing)
         out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
@@ -944,8 +995,8 @@ class Exporter:
             nms_spec.description.output.add()
             nms_spec.description.output[i].ParseFromString(decoder_output)

-        nms_spec.description.output[0].name = 'confidence'
-        nms_spec.description.output[1].name = 'coordinates'
+        nms_spec.description.output[0].name = "confidence"
+        nms_spec.description.output[1].name = "coordinates"

         output_sizes = [nc, 4]
         for i in range(2):
@@ -961,10 +1012,10 @@ class Exporter:
         nms = nms_spec.nonMaximumSuppression
         nms.confidenceInputFeatureName = out0.name  # 1x507x80
         nms.coordinatesInputFeatureName = out1.name  # 1x507x4
-        nms.confidenceOutputFeatureName = 'confidence'
-        nms.coordinatesOutputFeatureName = 'coordinates'
-        nms.iouThresholdInputFeatureName = 'iouThreshold'
-        nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
+        nms.confidenceOutputFeatureName = "confidence"
+        nms.coordinatesOutputFeatureName = "coordinates"
+        nms.iouThresholdInputFeatureName = "iouThreshold"
+        nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
         nms.iouThreshold = 0.45
         nms.confidenceThreshold = 0.25
         nms.pickTop.perClass = True
@@ -972,10 +1023,14 @@ class Exporter:
         nms_model = ct.models.MLModel(nms_spec)

         # 4. Pipeline models together
-        pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
-                                                               ('iouThreshold', ct.models.datatypes.Double()),
-                                                               ('confidenceThreshold', ct.models.datatypes.Double())],
-                                               output_features=['confidence', 'coordinates'])
+        pipeline = ct.models.pipeline.Pipeline(
+            input_features=[
+                ("image", ct.models.datatypes.Array(3, ny, nx)),
+                ("iouThreshold", ct.models.datatypes.Double()),
+                ("confidenceThreshold", ct.models.datatypes.Double()),
+            ],
+            output_features=["confidence", "coordinates"],
+        )
         pipeline.add_model(model)
         pipeline.add_model(nms_model)

@@ -986,19 +1041,20 @@ class Exporter:

         # Update metadata
         pipeline.spec.specificationVersion = 5
-        pipeline.spec.description.metadata.userDefined.update({
-            'IoU threshold': str(nms.iouThreshold),
-            'Confidence threshold': str(nms.confidenceThreshold)})
+        pipeline.spec.description.metadata.userDefined.update(
+            {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
+        )

         # Save the model
         model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
-        model.input_description['image'] = 'Input image'
-        model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
-        model.input_description['confidenceThreshold'] = \
-            f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
-        model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
-        model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
-        LOGGER.info(f'{prefix} pipeline success')
+        model.input_description["image"] = "Input image"
+        model.input_description["iouThreshold"] = f"(optional) IOU threshold override (default: {nms.iouThreshold})"
+        model.input_description[
+            "confidenceThreshold"
+        ] = f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
+        model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
+        model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
+        LOGGER.info(f"{prefix} pipeline success")
         return model

     def add_callback(self, event: str, callback):
|