ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. See the registry's page for this release for more details.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -67,8 +67,20 @@ from ultralytics.data.utils import check_det_dataset
67
67
  from ultralytics.nn.autobackend import check_class_names, default_class_names
68
68
  from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
69
69
  from ultralytics.nn.tasks import DetectionModel, SegmentationModel
70
- from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks,
71
- colorstr, get_default_args, yaml_save)
70
+ from ultralytics.utils import (
71
+ ARM64,
72
+ DEFAULT_CFG,
73
+ LINUX,
74
+ LOGGER,
75
+ MACOS,
76
+ ROOT,
77
+ WINDOWS,
78
+ __version__,
79
+ callbacks,
80
+ colorstr,
81
+ get_default_args,
82
+ yaml_save,
83
+ )
72
84
  from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
73
85
  from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
74
86
  from ultralytics.utils.files import file_size, spaces_in_path
@@ -79,21 +91,23 @@ from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart
79
91
  def export_formats():
80
92
  """YOLOv8 export formats."""
81
93
  import pandas
94
+
82
95
  x = [
83
- ['PyTorch', '-', '.pt', True, True],
84
- ['TorchScript', 'torchscript', '.torchscript', True, True],
85
- ['ONNX', 'onnx', '.onnx', True, True],
86
- ['OpenVINO', 'openvino', '_openvino_model', True, False],
87
- ['TensorRT', 'engine', '.engine', False, True],
88
- ['CoreML', 'coreml', '.mlpackage', True, False],
89
- ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
90
- ['TensorFlow GraphDef', 'pb', '.pb', True, True],
91
- ['TensorFlow Lite', 'tflite', '.tflite', True, False],
92
- ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
93
- ['TensorFlow.js', 'tfjs', '_web_model', True, False],
94
- ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
95
- ['ncnn', 'ncnn', '_ncnn_model', True, True], ]
96
- return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
96
+ ["PyTorch", "-", ".pt", True, True],
97
+ ["TorchScript", "torchscript", ".torchscript", True, True],
98
+ ["ONNX", "onnx", ".onnx", True, True],
99
+ ["OpenVINO", "openvino", "_openvino_model", True, False],
100
+ ["TensorRT", "engine", ".engine", False, True],
101
+ ["CoreML", "coreml", ".mlpackage", True, False],
102
+ ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
103
+ ["TensorFlow GraphDef", "pb", ".pb", True, True],
104
+ ["TensorFlow Lite", "tflite", ".tflite", True, False],
105
+ ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
106
+ ["TensorFlow.js", "tfjs", "_web_model", True, False],
107
+ ["PaddlePaddle", "paddle", "_paddle_model", True, True],
108
+ ["ncnn", "ncnn", "_ncnn_model", True, True],
109
+ ]
110
+ return pandas.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])
97
111
 
98
112
 
99
113
  def gd_outputs(gd):
@@ -102,7 +116,7 @@ def gd_outputs(gd):
102
116
  for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
103
117
  name_list.append(node.name)
104
118
  input_list.extend(node.input)
105
- return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
119
+ return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
106
120
 
107
121
 
108
122
  def try_export(inner_func):
@@ -111,14 +125,14 @@ def try_export(inner_func):
111
125
 
112
126
  def outer_func(*args, **kwargs):
113
127
  """Export a model."""
114
- prefix = inner_args['prefix']
128
+ prefix = inner_args["prefix"]
115
129
  try:
116
130
  with Profile() as dt:
117
131
  f, model = inner_func(*args, **kwargs)
118
132
  LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
119
133
  return f, model
120
134
  except Exception as e:
121
- LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
135
+ LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
122
136
  raise e
123
137
 
124
138
  return outer_func
@@ -143,8 +157,8 @@ class Exporter:
143
157
  _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
144
158
  """
145
159
  self.args = get_cfg(cfg, overrides)
146
- if self.args.format.lower() in ('coreml', 'mlmodel'): # fix attempt for protobuf<3.20.x errors
147
- os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' # must run before TensorBoard callback
160
+ if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
161
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
148
162
 
149
163
  self.callbacks = _callbacks or callbacks.get_default_callbacks()
150
164
  callbacks.add_integration_callbacks(self)
@@ -152,45 +166,46 @@ class Exporter:
152
166
  @smart_inference_mode()
153
167
  def __call__(self, model=None):
154
168
  """Returns list of exported files/dirs after running callbacks."""
155
- self.run_callbacks('on_export_start')
169
+ self.run_callbacks("on_export_start")
156
170
  t = time.time()
157
171
  fmt = self.args.format.lower() # to lowercase
158
- if fmt in ('tensorrt', 'trt'): # 'engine' aliases
159
- fmt = 'engine'
160
- if fmt in ('mlmodel', 'mlpackage', 'mlprogram', 'apple', 'ios', 'coreml'): # 'coreml' aliases
161
- fmt = 'coreml'
162
- fmts = tuple(export_formats()['Argument'][1:]) # available export formats
172
+ if fmt in ("tensorrt", "trt"): # 'engine' aliases
173
+ fmt = "engine"
174
+ if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases
175
+ fmt = "coreml"
176
+ fmts = tuple(export_formats()["Argument"][1:]) # available export formats
163
177
  flags = [x == fmt for x in fmts]
164
178
  if sum(flags) != 1:
165
179
  raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
166
180
  jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
167
181
 
168
182
  # Device
169
- if fmt == 'engine' and self.args.device is None:
170
- LOGGER.warning('WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0')
171
- self.args.device = '0'
172
- self.device = select_device('cpu' if self.args.device is None else self.args.device)
183
+ if fmt == "engine" and self.args.device is None:
184
+ LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
185
+ self.args.device = "0"
186
+ self.device = select_device("cpu" if self.args.device is None else self.args.device)
173
187
 
174
188
  # Checks
175
- if not hasattr(model, 'names'):
189
+ if not hasattr(model, "names"):
176
190
  model.names = default_class_names()
177
191
  model.names = check_class_names(model.names)
178
- if self.args.half and onnx and self.device.type == 'cpu':
179
- LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
192
+ if self.args.half and onnx and self.device.type == "cpu":
193
+ LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0")
180
194
  self.args.half = False
181
- assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
195
+ assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one."
182
196
  self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
183
197
  if self.args.optimize:
184
198
  assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
185
- assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
199
+ assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
186
200
  if edgetpu and not LINUX:
187
- raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')
201
+ raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
188
202
 
189
203
  # Input
190
204
  im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
191
205
  file = Path(
192
- getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
193
- if file.suffix in {'.yaml', '.yml'}:
206
+ getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
207
+ )
208
+ if file.suffix in {".yaml", ".yml"}:
194
209
  file = Path(file.name)
195
210
 
196
211
  # Update model
@@ -212,42 +227,48 @@ class Exporter:
212
227
  y = None
213
228
  for _ in range(2):
214
229
  y = model(im) # dry runs
215
- if self.args.half and (engine or onnx) and self.device.type != 'cpu':
230
+ if self.args.half and (engine or onnx) and self.device.type != "cpu":
216
231
  im, model = im.half(), model.half() # to FP16
217
232
 
218
233
  # Filter warnings
219
- warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
220
- warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant missing ONNX warning
221
- warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
234
+ warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning
235
+ warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning
236
+ warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
222
237
 
223
238
  # Assign
224
239
  self.im = im
225
240
  self.model = model
226
241
  self.file = file
227
- self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(
228
- tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
229
- self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
230
- data = model.args['data'] if hasattr(model, 'args') and isinstance(model.args, dict) else ''
242
+ self.output_shape = (
243
+ tuple(y.shape)
244
+ if isinstance(y, torch.Tensor)
245
+ else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
246
+ )
247
+ self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
248
+ data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
231
249
  description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}'
232
250
  self.metadata = {
233
- 'description': description,
234
- 'author': 'Ultralytics',
235
- 'license': 'AGPL-3.0 https://ultralytics.com/license',
236
- 'date': datetime.now().isoformat(),
237
- 'version': __version__,
238
- 'stride': int(max(model.stride)),
239
- 'task': model.task,
240
- 'batch': self.args.batch,
241
- 'imgsz': self.imgsz,
242
- 'names': model.names} # model metadata
243
- if model.task == 'pose':
244
- self.metadata['kpt_shape'] = model.model[-1].kpt_shape
245
-
246
- LOGGER.info(f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
247
- f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
251
+ "description": description,
252
+ "author": "Ultralytics",
253
+ "license": "AGPL-3.0 https://ultralytics.com/license",
254
+ "date": datetime.now().isoformat(),
255
+ "version": __version__,
256
+ "stride": int(max(model.stride)),
257
+ "task": model.task,
258
+ "batch": self.args.batch,
259
+ "imgsz": self.imgsz,
260
+ "names": model.names,
261
+ } # model metadata
262
+ if model.task == "pose":
263
+ self.metadata["kpt_shape"] = model.model[-1].kpt_shape
264
+
265
+ LOGGER.info(
266
+ f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
267
+ f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)'
268
+ )
248
269
 
249
270
  # Exports
250
- f = [''] * len(fmts) # exported filenames
271
+ f = [""] * len(fmts) # exported filenames
251
272
  if jit or ncnn: # TorchScript
252
273
  f[0], _ = self.export_torchscript()
253
274
  if engine: # TensorRT required before ONNX
@@ -266,7 +287,7 @@ class Exporter:
266
287
  if tflite:
267
288
  f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
268
289
  if edgetpu:
269
- f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
290
+ f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
270
291
  if tfjs:
271
292
  f[9], _ = self.export_tfjs()
272
293
  if paddle: # PaddlePaddle
@@ -279,58 +300,65 @@ class Exporter:
279
300
  if any(f):
280
301
  f = str(Path(f[-1]))
281
302
  square = self.imgsz[0] == self.imgsz[1]
282
- s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
283
- f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
284
- imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
285
- predict_data = f'data={data}' if model.task == 'segment' and fmt == 'pb' else ''
286
- q = 'int8' if self.args.int8 else 'half' if self.args.half else '' # quantization
287
- LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
288
- f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
289
- f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
290
- f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
291
- f'\nVisualize: https://netron.app')
292
-
293
- self.run_callbacks('on_export_end')
303
+ s = (
304
+ ""
305
+ if square
306
+ else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
307
+ f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
308
+ )
309
+ imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
310
+ predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
311
+ q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization
312
+ LOGGER.info(
313
+ f'\nExport complete ({time.time() - t:.1f}s)'
314
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
315
+ f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
316
+ f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
317
+ f'\nVisualize: https://netron.app'
318
+ )
319
+
320
+ self.run_callbacks("on_export_end")
294
321
  return f # return list of exported files/dirs
295
322
 
296
323
  @try_export
297
- def export_torchscript(self, prefix=colorstr('TorchScript:')):
324
+ def export_torchscript(self, prefix=colorstr("TorchScript:")):
298
325
  """YOLOv8 TorchScript model export."""
299
- LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
300
- f = self.file.with_suffix('.torchscript')
326
+ LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
327
+ f = self.file.with_suffix(".torchscript")
301
328
 
302
329
  ts = torch.jit.trace(self.model, self.im, strict=False)
303
- extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
330
+ extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
304
331
  if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
305
- LOGGER.info(f'{prefix} optimizing for mobile...')
332
+ LOGGER.info(f"{prefix} optimizing for mobile...")
306
333
  from torch.utils.mobile_optimizer import optimize_for_mobile
334
+
307
335
  optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
308
336
  else:
309
337
  ts.save(str(f), _extra_files=extra_files)
310
338
  return f, None
311
339
 
312
340
  @try_export
313
- def export_onnx(self, prefix=colorstr('ONNX:')):
341
+ def export_onnx(self, prefix=colorstr("ONNX:")):
314
342
  """YOLOv8 ONNX export."""
315
- requirements = ['onnx>=1.12.0']
343
+ requirements = ["onnx>=1.12.0"]
316
344
  if self.args.simplify:
317
- requirements += ['onnxsim>=0.4.33', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
345
+ requirements += ["onnxsim>=0.4.33", "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"]
318
346
  check_requirements(requirements)
319
347
  import onnx # noqa
320
348
 
321
349
  opset_version = self.args.opset or get_latest_opset()
322
- LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
323
- f = str(self.file.with_suffix('.onnx'))
350
+ LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
351
+ f = str(self.file.with_suffix(".onnx"))
324
352
 
325
- output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
353
+ output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
326
354
  dynamic = self.args.dynamic
327
355
  if dynamic:
328
- dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
356
+ dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640)
329
357
  if isinstance(self.model, SegmentationModel):
330
- dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 116, 8400)
331
- dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
358
+ dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 116, 8400)
359
+ dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160)
332
360
  elif isinstance(self.model, DetectionModel):
333
- dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 84, 8400)
361
+ dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400)
334
362
 
335
363
  torch.onnx.export(
336
364
  self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu
@@ -339,9 +367,10 @@ class Exporter:
339
367
  verbose=False,
340
368
  opset_version=opset_version,
341
369
  do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
342
- input_names=['images'],
370
+ input_names=["images"],
343
371
  output_names=output_names,
344
- dynamic_axes=dynamic or None)
372
+ dynamic_axes=dynamic or None,
373
+ )
345
374
 
346
375
  # Checks
347
376
  model_onnx = onnx.load(f) # load onnx model
@@ -352,12 +381,12 @@ class Exporter:
352
381
  try:
353
382
  import onnxsim
354
383
 
355
- LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
384
+ LOGGER.info(f"{prefix} simplifying with onnxsim {onnxsim.__version__}...")
356
385
  # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
357
386
  model_onnx, check = onnxsim.simplify(model_onnx)
358
- assert check, 'Simplified ONNX model could not be validated'
387
+ assert check, "Simplified ONNX model could not be validated"
359
388
  except Exception as e:
360
- LOGGER.info(f'{prefix} simplifier failure: {e}')
389
+ LOGGER.info(f"{prefix} simplifier failure: {e}")
361
390
 
362
391
  # Metadata
363
392
  for k, v in self.metadata.items():
@@ -368,58 +397,56 @@ class Exporter:
368
397
  return f, model_onnx
369
398
 
370
399
  @try_export
371
- def export_openvino(self, prefix=colorstr('OpenVINO:')):
400
+ def export_openvino(self, prefix=colorstr("OpenVINO:")):
372
401
  """YOLOv8 OpenVINO export."""
373
- check_requirements('openvino-dev>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/
402
+ check_requirements("openvino-dev>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/
374
403
  import openvino.runtime as ov # noqa
375
404
  from openvino.tools import mo # noqa
376
405
 
377
- LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
378
- f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
379
- fq = str(self.file).replace(self.file.suffix, f'_int8_openvino_model{os.sep}')
380
- f_onnx = self.file.with_suffix('.onnx')
381
- f_ov = str(Path(f) / self.file.with_suffix('.xml').name)
382
- fq_ov = str(Path(fq) / self.file.with_suffix('.xml').name)
406
+ LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
407
+ f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
408
+ fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
409
+ f_onnx = self.file.with_suffix(".onnx")
410
+ f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
411
+ fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
383
412
 
384
413
  def serialize(ov_model, file):
385
414
  """Set RT info, serialize and save metadata YAML."""
386
- ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
387
- ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
388
- ov_model.set_rt_info(114, ['model_info', 'pad_value'])
389
- ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
390
- ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
391
- ov_model.set_rt_info([v.replace(' ', '_') for v in self.model.names.values()], ['model_info', 'labels'])
392
- if self.model.task != 'classify':
393
- ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])
415
+ ov_model.set_rt_info("YOLOv8", ["model_info", "model_type"])
416
+ ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
417
+ ov_model.set_rt_info(114, ["model_info", "pad_value"])
418
+ ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
419
+ ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
420
+ ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
421
+ if self.model.task != "classify":
422
+ ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
394
423
 
395
424
  ov.serialize(ov_model, file) # save
396
- yaml_save(Path(file).parent / 'metadata.yaml', self.metadata) # add metadata.yaml
425
+ yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml
397
426
 
398
- ov_model = mo.convert_model(f_onnx,
399
- model_name=self.pretty_name,
400
- framework='onnx',
401
- compress_to_fp16=self.args.half) # export
427
+ ov_model = mo.convert_model(
428
+ f_onnx, model_name=self.pretty_name, framework="onnx", compress_to_fp16=self.args.half
429
+ ) # export
402
430
 
403
431
  if self.args.int8:
404
432
  assert self.args.data, "INT8 export requires a data argument for calibration, i.e. 'data=coco8.yaml'"
405
- check_requirements('nncf>=2.5.0')
433
+ check_requirements("nncf>=2.5.0")
406
434
  import nncf
407
435
 
408
436
  def transform_fn(data_item):
409
437
  """Quantization transform function."""
410
- im = data_item['img'].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
438
+ im = data_item["img"].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
411
439
  return np.expand_dims(im, 0) if im.ndim == 3 else im
412
440
 
413
441
  # Generate calibration data for integer quantization
414
442
  LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
415
443
  data = check_det_dataset(self.args.data)
416
- dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
444
+ dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
417
445
  quantization_dataset = nncf.Dataset(dataset, transform_fn)
418
- ignored_scope = nncf.IgnoredScope(types=['Multiply', 'Subtract', 'Sigmoid']) # ignore operation
419
- quantized_ov_model = nncf.quantize(ov_model,
420
- quantization_dataset,
421
- preset=nncf.QuantizationPreset.MIXED,
422
- ignored_scope=ignored_scope)
446
+ ignored_scope = nncf.IgnoredScope(types=["Multiply", "Subtract", "Sigmoid"]) # ignore operation
447
+ quantized_ov_model = nncf.quantize(
448
+ ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope
449
+ )
423
450
  serialize(quantized_ov_model, fq_ov)
424
451
  return fq, None
425
452
 
@@ -427,48 +454,49 @@ class Exporter:
427
454
  return f, None
428
455
 
429
456
  @try_export
430
- def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
457
+ def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
431
458
  """YOLOv8 Paddle export."""
432
- check_requirements(('paddlepaddle', 'x2paddle'))
459
+ check_requirements(("paddlepaddle", "x2paddle"))
433
460
  import x2paddle # noqa
434
461
  from x2paddle.convert import pytorch2paddle # noqa
435
462
 
436
- LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
437
- f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
463
+ LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
464
+ f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")
438
465
 
439
- pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export
440
- yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
466
+ pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im]) # export
467
+ yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
441
468
  return f, None
442
469
 
443
470
  @try_export
444
- def export_ncnn(self, prefix=colorstr('ncnn:')):
471
+ def export_ncnn(self, prefix=colorstr("ncnn:")):
445
472
  """
446
473
  YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx.
447
474
  """
448
- check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn
475
+ check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires ncnn
449
476
  import ncnn # noqa
450
477
 
451
- LOGGER.info(f'\n{prefix} starting export with ncnn {ncnn.__version__}...')
452
- f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
453
- f_ts = self.file.with_suffix('.torchscript')
478
+ LOGGER.info(f"\n{prefix} starting export with ncnn {ncnn.__version__}...")
479
+ f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
480
+ f_ts = self.file.with_suffix(".torchscript")
454
481
 
455
- name = Path('pnnx.exe' if WINDOWS else 'pnnx') # PNNX filename
482
+ name = Path("pnnx.exe" if WINDOWS else "pnnx") # PNNX filename
456
483
  pnnx = name if name.is_file() else ROOT / name
457
484
  if not pnnx.is_file():
458
485
  LOGGER.warning(
459
- f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
460
- 'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
461
- f'or in {ROOT}. See PNNX repo for full installation instructions.')
462
- system = ['macos'] if MACOS else ['windows'] if WINDOWS else ['ubuntu', 'linux'] # operating system
486
+ f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from "
487
+ "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
488
+ f"or in {ROOT}. See PNNX repo for full installation instructions."
489
+ )
490
+ system = ["macos"] if MACOS else ["windows"] if WINDOWS else ["ubuntu", "linux"] # operating system
463
491
  try:
464
- _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
492
+ _, assets = get_github_assets(repo="pnnx/pnnx", retry=True)
465
493
  url = [x for x in assets if any(s in x for s in system)][0]
466
494
  except Exception as e:
467
- url = f'https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip'
468
- LOGGER.warning(f'{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}')
469
- asset = attempt_download_asset(url, repo='pnnx/pnnx', release='latest')
495
+ url = f"https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip"
496
+ LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}")
497
+ asset = attempt_download_asset(url, repo="pnnx/pnnx", release="latest")
470
498
  if check_is_path_safe(Path.cwd(), asset): # avoid path traversal security vulnerability
471
- unzip_dir = Path(asset).with_suffix('')
499
+ unzip_dir = Path(asset).with_suffix("")
472
500
  (unzip_dir / name).rename(pnnx) # move binary to ROOT
473
501
  shutil.rmtree(unzip_dir) # delete unzip dir
474
502
  Path(asset).unlink() # delete zip
@@ -477,53 +505,56 @@ class Exporter:
477
505
  ncnn_args = [
478
506
  f'ncnnparam={f / "model.ncnn.param"}',
479
507
  f'ncnnbin={f / "model.ncnn.bin"}',
480
- f'ncnnpy={f / "model_ncnn.py"}', ]
508
+ f'ncnnpy={f / "model_ncnn.py"}',
509
+ ]
481
510
 
482
511
  pnnx_args = [
483
512
  f'pnnxparam={f / "model.pnnx.param"}',
484
513
  f'pnnxbin={f / "model.pnnx.bin"}',
485
514
  f'pnnxpy={f / "model_pnnx.py"}',
486
- f'pnnxonnx={f / "model.pnnx.onnx"}', ]
515
+ f'pnnxonnx={f / "model.pnnx.onnx"}',
516
+ ]
487
517
 
488
518
  cmd = [
489
519
  str(pnnx),
490
520
  str(f_ts),
491
521
  *ncnn_args,
492
522
  *pnnx_args,
493
- f'fp16={int(self.args.half)}',
494
- f'device={self.device.type}',
495
- f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
523
+ f"fp16={int(self.args.half)}",
524
+ f"device={self.device.type}",
525
+ f'inputshape="{[self.args.batch, 3, *self.imgsz]}"',
526
+ ]
496
527
  f.mkdir(exist_ok=True) # make ncnn_model directory
497
528
  LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
498
529
  subprocess.run(cmd, check=True)
499
530
 
500
531
  # Remove debug files
501
- pnnx_files = [x.split('=')[-1] for x in pnnx_args]
502
- for f_debug in ('debug.bin', 'debug.param', 'debug2.bin', 'debug2.param', *pnnx_files):
532
+ pnnx_files = [x.split("=")[-1] for x in pnnx_args]
533
+ for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
503
534
  Path(f_debug).unlink(missing_ok=True)
504
535
 
505
- yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
536
+ yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml
506
537
  return str(f), None
507
538
 
508
539
  @try_export
509
- def export_coreml(self, prefix=colorstr('CoreML:')):
540
+ def export_coreml(self, prefix=colorstr("CoreML:")):
510
541
  """YOLOv8 CoreML export."""
511
- mlmodel = self.args.format.lower() == 'mlmodel' # legacy *.mlmodel export format requested
512
- check_requirements('coremltools>=6.0,<=6.2' if mlmodel else 'coremltools>=7.0')
542
+ mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested
543
+ check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=7.0")
513
544
  import coremltools as ct # noqa
514
545
 
515
- LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
516
- f = self.file.with_suffix('.mlmodel' if mlmodel else '.mlpackage')
546
+ LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
547
+ f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
517
548
  if f.is_dir():
518
549
  shutil.rmtree(f)
519
550
 
520
551
  bias = [0.0, 0.0, 0.0]
521
552
  scale = 1 / 255
522
553
  classifier_config = None
523
- if self.model.task == 'classify':
554
+ if self.model.task == "classify":
524
555
  classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
525
556
  model = self.model
526
- elif self.model.task == 'detect':
557
+ elif self.model.task == "detect":
527
558
  model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
528
559
  else:
529
560
  if self.args.nms:
@@ -532,69 +563,73 @@ class Exporter:
532
563
  model = self.model
533
564
 
534
565
  ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
535
- ct_model = ct.convert(ts,
536
- inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
537
- classifier_config=classifier_config,
538
- convert_to='neuralnetwork' if mlmodel else 'mlprogram')
539
- bits, mode = (8, 'kmeans') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
566
+ ct_model = ct.convert(
567
+ ts,
568
+ inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
569
+ classifier_config=classifier_config,
570
+ convert_to="neuralnetwork" if mlmodel else "mlprogram",
571
+ )
572
+ bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
540
573
  if bits < 32:
541
- if 'kmeans' in mode:
542
- check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
574
+ if "kmeans" in mode:
575
+ check_requirements("scikit-learn") # scikit-learn package required for k-means quantization
543
576
  if mlmodel:
544
577
  ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
545
578
  elif bits == 8: # mlprogram already quantized to FP16
546
579
  import coremltools.optimize.coreml as cto
547
- op_config = cto.OpPalettizerConfig(mode='kmeans', nbits=bits, weight_threshold=512)
580
+
581
+ op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
548
582
  config = cto.OptimizationConfig(global_config=op_config)
549
583
  ct_model = cto.palettize_weights(ct_model, config=config)
550
- if self.args.nms and self.model.task == 'detect':
584
+ if self.args.nms and self.model.task == "detect":
551
585
  if mlmodel:
552
586
  import platform
553
587
 
554
588
  # coremltools<=6.2 NMS export requires Python<3.11
555
- check_version(platform.python_version(), '<3.11', name='Python ', hard=True)
589
+ check_version(platform.python_version(), "<3.11", name="Python ", hard=True)
556
590
  weights_dir = None
557
591
  else:
558
592
  ct_model.save(str(f)) # save otherwise weights_dir does not exist
559
- weights_dir = str(f / 'Data/com.apple.CoreML/weights')
593
+ weights_dir = str(f / "Data/com.apple.CoreML/weights")
560
594
  ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)
561
595
 
562
596
  m = self.metadata # metadata dict
563
- ct_model.short_description = m.pop('description')
564
- ct_model.author = m.pop('author')
565
- ct_model.license = m.pop('license')
566
- ct_model.version = m.pop('version')
597
+ ct_model.short_description = m.pop("description")
598
+ ct_model.author = m.pop("author")
599
+ ct_model.license = m.pop("license")
600
+ ct_model.version = m.pop("version")
567
601
  ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
568
602
  try:
569
603
  ct_model.save(str(f)) # save *.mlpackage
570
604
  except Exception as e:
571
605
  LOGGER.warning(
572
- f'{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. '
573
- f'Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928.')
574
- f = f.with_suffix('.mlmodel')
606
+ f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
607
+ f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
608
+ )
609
+ f = f.with_suffix(".mlmodel")
575
610
  ct_model.save(str(f))
576
611
  return f, ct_model
577
612
 
578
613
  @try_export
579
- def export_engine(self, prefix=colorstr('TensorRT:')):
614
+ def export_engine(self, prefix=colorstr("TensorRT:")):
580
615
  """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
581
- assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
616
+ assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
582
617
  f_onnx, _ = self.export_onnx() # run before trt import https://github.com/ultralytics/ultralytics/issues/7016
583
618
 
584
619
  try:
585
620
  import tensorrt as trt # noqa
586
621
  except ImportError:
587
622
  if LINUX:
588
- check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
623
+ check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
589
624
  import tensorrt as trt # noqa
590
625
 
591
- check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
626
+ check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0
592
627
 
593
628
  self.args.simplify = True
594
629
 
595
- LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
596
- assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
597
- f = self.file.with_suffix('.engine') # TensorRT engine file
630
+ LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
631
+ assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
632
+ f = self.file.with_suffix(".engine") # TensorRT engine file
598
633
  logger = trt.Logger(trt.Logger.INFO)
599
634
  if self.args.verbose:
600
635
  logger.min_severity = trt.Logger.Severity.VERBOSE
@@ -604,11 +639,11 @@ class Exporter:
604
639
  config.max_workspace_size = self.args.workspace * 1 << 30
605
640
  # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
606
641
 
607
- flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
642
+ flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
608
643
  network = builder.create_network(flag)
609
644
  parser = trt.OnnxParser(network, logger)
610
645
  if not parser.parse_from_file(f_onnx):
611
- raise RuntimeError(f'failed to load ONNX file: {f_onnx}')
646
+ raise RuntimeError(f"failed to load ONNX file: {f_onnx}")
612
647
 
613
648
  inputs = [network.get_input(i) for i in range(network.num_inputs)]
614
649
  outputs = [network.get_output(i) for i in range(network.num_outputs)]
@@ -627,7 +662,8 @@ class Exporter:
627
662
  config.add_optimization_profile(profile)
628
663
 
629
664
  LOGGER.info(
630
- f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
665
+ f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
666
+ )
631
667
  if builder.platform_has_fast_fp16 and self.args.half:
632
668
  config.set_flag(trt.BuilderFlag.FP16)
633
669
 
@@ -635,10 +671,10 @@ class Exporter:
635
671
  torch.cuda.empty_cache()
636
672
 
637
673
  # Write file
638
- with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
674
+ with builder.build_engine(network, config) as engine, open(f, "wb") as t:
639
675
  # Metadata
640
676
  meta = json.dumps(self.metadata)
641
- t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
677
+ t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
642
678
  t.write(meta.encode())
643
679
  # Model
644
680
  t.write(engine.serialize())
@@ -646,7 +682,7 @@ class Exporter:
646
682
  return f, None
647
683
 
648
684
  @try_export
649
- def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
685
+ def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
650
686
  """YOLOv8 TensorFlow SavedModel export."""
651
687
  cuda = torch.cuda.is_available()
652
688
  try:
@@ -655,44 +691,55 @@ class Exporter:
655
691
  check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
656
692
  import tensorflow as tf # noqa
657
693
  check_requirements(
658
- ('onnx', 'onnx2tf>=1.15.4,<=1.17.5', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.33', 'onnx_graphsurgeon>=0.3.26',
659
- 'tflite_support', 'onnxruntime-gpu' if cuda else 'onnxruntime'),
660
- cmds='--extra-index-url https://pypi.ngc.nvidia.com') # onnx_graphsurgeon only on NVIDIA
661
-
662
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
663
- check_version(tf.__version__,
664
- '<=2.13.1',
665
- name='tensorflow',
666
- verbose=True,
667
- msg='https://github.com/ultralytics/ultralytics/issues/5161')
668
- f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
694
+ (
695
+ "onnx",
696
+ "onnx2tf>=1.15.4,<=1.17.5",
697
+ "sng4onnx>=1.0.1",
698
+ "onnxsim>=0.4.33",
699
+ "onnx_graphsurgeon>=0.3.26",
700
+ "tflite_support",
701
+ "onnxruntime-gpu" if cuda else "onnxruntime",
702
+ ),
703
+ cmds="--extra-index-url https://pypi.ngc.nvidia.com",
704
+ ) # onnx_graphsurgeon only on NVIDIA
705
+
706
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
707
+ check_version(
708
+ tf.__version__,
709
+ "<=2.13.1",
710
+ name="tensorflow",
711
+ verbose=True,
712
+ msg="https://github.com/ultralytics/ultralytics/issues/5161",
713
+ )
714
+ f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
669
715
  if f.is_dir():
670
716
  import shutil
717
+
671
718
  shutil.rmtree(f) # delete output folder
672
719
 
673
720
  # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
674
- onnx2tf_file = Path('calibration_image_sample_data_20x128x128x3_float32.npy')
721
+ onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
675
722
  if not onnx2tf_file.exists():
676
- attempt_download_asset(f'{onnx2tf_file}.zip', unzip=True, delete=True)
723
+ attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
677
724
 
678
725
  # Export to ONNX
679
726
  self.args.simplify = True
680
727
  f_onnx, _ = self.export_onnx()
681
728
 
682
729
  # Export to TF
683
- tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
730
+ tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
684
731
  if self.args.int8:
685
- verbosity = '--verbosity info'
732
+ verbosity = "--verbosity info"
686
733
  if self.args.data:
687
734
  # Generate calibration data for integer quantization
688
735
  LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
689
736
  data = check_det_dataset(self.args.data)
690
- dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
737
+ dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
691
738
  images = []
692
739
  for i, batch in enumerate(dataset):
693
740
  if i >= 100: # maximum number of calibration images
694
741
  break
695
- im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC
742
+ im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC
696
743
  images.append(im)
697
744
  f.mkdir()
698
745
  images = torch.cat(images, 0).float()
@@ -701,38 +748,38 @@ class Exporter:
701
748
  np.save(str(tmp_file), images.numpy()) # BHWC
702
749
  int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
703
750
  else:
704
- int8 = '-oiqt -qt per-tensor'
751
+ int8 = "-oiqt -qt per-tensor"
705
752
  else:
706
- verbosity = '--non_verbose'
707
- int8 = ''
753
+ verbosity = "--non_verbose"
754
+ int8 = ""
708
755
 
709
756
  cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo {verbosity} {int8}'.strip()
710
757
  LOGGER.info(f"{prefix} running '{cmd}'")
711
758
  subprocess.run(cmd, shell=True)
712
- yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
759
+ yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml
713
760
 
714
761
  # Remove/rename TFLite models
715
762
  if self.args.int8:
716
763
  tmp_file.unlink(missing_ok=True)
717
- for file in f.rglob('*_dynamic_range_quant.tflite'):
718
- file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
719
- for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
764
+ for file in f.rglob("*_dynamic_range_quant.tflite"):
765
+ file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
766
+ for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
720
767
  file.unlink() # delete extra fp16 activation TFLite files
721
768
 
722
769
  # Add TFLite metadata
723
- for file in f.rglob('*.tflite'):
724
- f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file)
770
+ for file in f.rglob("*.tflite"):
771
+ f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file)
725
772
 
726
773
  return str(f), tf.saved_model.load(f, tags=None, options=None) # load saved_model as Keras model
727
774
 
728
775
  @try_export
729
- def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
776
+ def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
730
777
  """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
731
778
  import tensorflow as tf # noqa
732
779
  from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
733
780
 
734
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
735
- f = self.file.with_suffix('.pb')
781
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
782
+ f = self.file.with_suffix(".pb")
736
783
 
737
784
  m = tf.function(lambda x: keras_model(x)) # full model
738
785
  m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
@@ -742,40 +789,43 @@ class Exporter:
742
789
  return f, None
743
790
 
744
791
  @try_export
745
- def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
792
+ def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
746
793
  """YOLOv8 TensorFlow Lite export."""
747
794
  import tensorflow as tf # noqa
748
795
 
749
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
750
- saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
796
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
797
+ saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
751
798
  if self.args.int8:
752
- f = saved_model / f'{self.file.stem}_int8.tflite' # fp32 in/out
799
+ f = saved_model / f"{self.file.stem}_int8.tflite" # fp32 in/out
753
800
  elif self.args.half:
754
- f = saved_model / f'{self.file.stem}_float16.tflite' # fp32 in/out
801
+ f = saved_model / f"{self.file.stem}_float16.tflite" # fp32 in/out
755
802
  else:
756
- f = saved_model / f'{self.file.stem}_float32.tflite'
803
+ f = saved_model / f"{self.file.stem}_float32.tflite"
757
804
  return str(f), None
758
805
 
759
806
  @try_export
760
- def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
807
+ def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
761
808
  """YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
762
- LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
809
+ LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")
763
810
 
764
- cmd = 'edgetpu_compiler --version'
765
- help_url = 'https://coral.ai/docs/edgetpu/compiler/'
766
- assert LINUX, f'export only supported on Linux. See {help_url}'
811
+ cmd = "edgetpu_compiler --version"
812
+ help_url = "https://coral.ai/docs/edgetpu/compiler/"
813
+ assert LINUX, f"export only supported on Linux. See {help_url}"
767
814
  if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
768
- LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
769
- sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
770
- for c in ('curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
771
- 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
772
- 'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', 'sudo apt-get update',
773
- 'sudo apt-get install edgetpu-compiler'):
774
- subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
815
+ LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
816
+ sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0 # sudo installed on system
817
+ for c in (
818
+ "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
819
+ 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
820
+ "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
821
+ "sudo apt-get update",
822
+ "sudo apt-get install edgetpu-compiler",
823
+ ):
824
+ subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True)
775
825
  ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
776
826
 
777
- LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
778
- f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
827
+ LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
828
+ f = str(tflite_model).replace(".tflite", "_edgetpu.tflite") # Edge TPU model
779
829
 
780
830
  cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
781
831
  LOGGER.info(f"{prefix} running '{cmd}'")
@@ -784,30 +834,30 @@ class Exporter:
784
834
  return f, None
785
835
 
786
836
  @try_export
787
- def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
837
+ def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
788
838
  """YOLOv8 TensorFlow.js export."""
789
839
  # JAX bug requiring install constraints in https://github.com/google/jax/issues/18978
790
- check_requirements(['jax<=0.4.21', 'jaxlib<=0.4.21', 'tensorflowjs'])
840
+ check_requirements(["jax<=0.4.21", "jaxlib<=0.4.21", "tensorflowjs"])
791
841
  import tensorflow as tf
792
842
  import tensorflowjs as tfjs # noqa
793
843
 
794
- LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
795
- f = str(self.file).replace(self.file.suffix, '_web_model') # js dir
796
- f_pb = str(self.file.with_suffix('.pb')) # *.pb path
844
+ LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
845
+ f = str(self.file).replace(self.file.suffix, "_web_model") # js dir
846
+ f_pb = str(self.file.with_suffix(".pb")) # *.pb path
797
847
 
798
848
  gd = tf.Graph().as_graph_def() # TF GraphDef
799
- with open(f_pb, 'rb') as file:
849
+ with open(f_pb, "rb") as file:
800
850
  gd.ParseFromString(file.read())
801
- outputs = ','.join(gd_outputs(gd))
802
- LOGGER.info(f'\n{prefix} output node names: {outputs}')
851
+ outputs = ",".join(gd_outputs(gd))
852
+ LOGGER.info(f"\n{prefix} output node names: {outputs}")
803
853
 
804
- quantization = '--quantize_float16' if self.args.half else '--quantize_uint8' if self.args.int8 else ''
854
+ quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
805
855
  with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path
806
856
  cmd = f'tensorflowjs_converter --input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
807
857
  LOGGER.info(f"{prefix} running '{cmd}'")
808
858
  subprocess.run(cmd, shell=True)
809
859
 
810
- if ' ' in f:
860
+ if " " in f:
811
861
  LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")
812
862
 
813
863
  # f_json = Path(f) / 'model.json' # *.json path
@@ -824,7 +874,7 @@ class Exporter:
824
874
  # f_json.read_text(),
825
875
  # )
826
876
  # j.write(subst)
827
- yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
877
+ yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
828
878
  return f, None
829
879
 
830
880
  def _add_tflite_metadata(self, file):
@@ -835,14 +885,14 @@ class Exporter:
835
885
 
836
886
  # Create model info
837
887
  model_meta = _metadata_fb.ModelMetadataT()
838
- model_meta.name = self.metadata['description']
839
- model_meta.version = self.metadata['version']
840
- model_meta.author = self.metadata['author']
841
- model_meta.license = self.metadata['license']
888
+ model_meta.name = self.metadata["description"]
889
+ model_meta.version = self.metadata["version"]
890
+ model_meta.author = self.metadata["author"]
891
+ model_meta.license = self.metadata["license"]
842
892
 
843
893
  # Label file
844
- tmp_file = Path(file).parent / 'temp_meta.txt'
845
- with open(tmp_file, 'w') as f:
894
+ tmp_file = Path(file).parent / "temp_meta.txt"
895
+ with open(tmp_file, "w") as f:
846
896
  f.write(str(self.metadata))
847
897
 
848
898
  label_file = _metadata_fb.AssociatedFileT()
@@ -851,8 +901,8 @@ class Exporter:
851
901
 
852
902
  # Create input info
853
903
  input_meta = _metadata_fb.TensorMetadataT()
854
- input_meta.name = 'image'
855
- input_meta.description = 'Input image to be detected.'
904
+ input_meta.name = "image"
905
+ input_meta.description = "Input image to be detected."
856
906
  input_meta.content = _metadata_fb.ContentT()
857
907
  input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
858
908
  input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
@@ -860,19 +910,19 @@ class Exporter:
860
910
 
861
911
  # Create output info
862
912
  output1 = _metadata_fb.TensorMetadataT()
863
- output1.name = 'output'
864
- output1.description = 'Coordinates of detected objects, class labels, and confidence score'
913
+ output1.name = "output"
914
+ output1.description = "Coordinates of detected objects, class labels, and confidence score"
865
915
  output1.associatedFiles = [label_file]
866
- if self.model.task == 'segment':
916
+ if self.model.task == "segment":
867
917
  output2 = _metadata_fb.TensorMetadataT()
868
- output2.name = 'output'
869
- output2.description = 'Mask protos'
918
+ output2.name = "output"
919
+ output2.description = "Mask protos"
870
920
  output2.associatedFiles = [label_file]
871
921
 
872
922
  # Create subgraph info
873
923
  subgraph = _metadata_fb.SubGraphMetadataT()
874
924
  subgraph.inputTensorMetadata = [input_meta]
875
- subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
925
+ subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
876
926
  model_meta.subgraphMetadata = [subgraph]
877
927
 
878
928
  b = flatbuffers.Builder(0)
@@ -885,11 +935,11 @@ class Exporter:
885
935
  populator.populate()
886
936
  tmp_file.unlink()
887
937
 
888
- def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr('CoreML Pipeline:')):
938
+ def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
889
939
  """YOLOv8 CoreML pipeline."""
890
940
  import coremltools as ct # noqa
891
941
 
892
- LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
942
+ LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
893
943
  _, _, h, w = list(self.im.shape) # BCHW
894
944
 
895
945
  # Output shapes
@@ -897,8 +947,9 @@ class Exporter:
897
947
  out0, out1 = iter(spec.description.output)
898
948
  if MACOS:
899
949
  from PIL import Image
900
- img = Image.new('RGB', (w, h)) # w=192, h=320
901
- out = model.predict({'image': img})
950
+
951
+ img = Image.new("RGB", (w, h)) # w=192, h=320
952
+ out = model.predict({"image": img})
902
953
  out0_shape = out[out0.name].shape # (3780, 80)
903
954
  out1_shape = out[out1.name].shape # (3780, 4)
904
955
  else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y
@@ -906,11 +957,11 @@ class Exporter:
906
957
  out1_shape = self.output_shape[2], 4 # (3780, 4)
907
958
 
908
959
  # Checks
909
- names = self.metadata['names']
960
+ names = self.metadata["names"]
910
961
  nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
911
962
  _, nc = out0_shape # number of anchors, number of classes
912
963
  # _, nc = out0.type.multiArrayType.shape
913
- assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check
964
+ assert len(names) == nc, f"{len(names)} names found for nc={nc}" # check
914
965
 
915
966
  # Define output shapes (missing)
916
967
  out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
@@ -944,8 +995,8 @@ class Exporter:
944
995
  nms_spec.description.output.add()
945
996
  nms_spec.description.output[i].ParseFromString(decoder_output)
946
997
 
947
- nms_spec.description.output[0].name = 'confidence'
948
- nms_spec.description.output[1].name = 'coordinates'
998
+ nms_spec.description.output[0].name = "confidence"
999
+ nms_spec.description.output[1].name = "coordinates"
949
1000
 
950
1001
  output_sizes = [nc, 4]
951
1002
  for i in range(2):
@@ -961,10 +1012,10 @@ class Exporter:
961
1012
  nms = nms_spec.nonMaximumSuppression
962
1013
  nms.confidenceInputFeatureName = out0.name # 1x507x80
963
1014
  nms.coordinatesInputFeatureName = out1.name # 1x507x4
964
- nms.confidenceOutputFeatureName = 'confidence'
965
- nms.coordinatesOutputFeatureName = 'coordinates'
966
- nms.iouThresholdInputFeatureName = 'iouThreshold'
967
- nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
1015
+ nms.confidenceOutputFeatureName = "confidence"
1016
+ nms.coordinatesOutputFeatureName = "coordinates"
1017
+ nms.iouThresholdInputFeatureName = "iouThreshold"
1018
+ nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
968
1019
  nms.iouThreshold = 0.45
969
1020
  nms.confidenceThreshold = 0.25
970
1021
  nms.pickTop.perClass = True
@@ -972,10 +1023,14 @@ class Exporter:
972
1023
  nms_model = ct.models.MLModel(nms_spec)
973
1024
 
974
1025
  # 4. Pipeline models together
975
- pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
976
- ('iouThreshold', ct.models.datatypes.Double()),
977
- ('confidenceThreshold', ct.models.datatypes.Double())],
978
- output_features=['confidence', 'coordinates'])
1026
+ pipeline = ct.models.pipeline.Pipeline(
1027
+ input_features=[
1028
+ ("image", ct.models.datatypes.Array(3, ny, nx)),
1029
+ ("iouThreshold", ct.models.datatypes.Double()),
1030
+ ("confidenceThreshold", ct.models.datatypes.Double()),
1031
+ ],
1032
+ output_features=["confidence", "coordinates"],
1033
+ )
979
1034
  pipeline.add_model(model)
980
1035
  pipeline.add_model(nms_model)
981
1036
 
@@ -986,19 +1041,20 @@ class Exporter:
986
1041
 
987
1042
  # Update metadata
988
1043
  pipeline.spec.specificationVersion = 5
989
- pipeline.spec.description.metadata.userDefined.update({
990
- 'IoU threshold': str(nms.iouThreshold),
991
- 'Confidence threshold': str(nms.confidenceThreshold)})
1044
+ pipeline.spec.description.metadata.userDefined.update(
1045
+ {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
1046
+ )
992
1047
 
993
1048
  # Save the model
994
1049
  model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
995
- model.input_description['image'] = 'Input image'
996
- model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
997
- model.input_description['confidenceThreshold'] = \
998
- f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
999
- model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
1000
- model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
1001
- LOGGER.info(f'{prefix} pipeline success')
1050
+ model.input_description["image"] = "Input image"
1051
+ model.input_description["iouThreshold"] = f"(optional) IOU threshold override (default: {nms.iouThreshold})"
1052
+ model.input_description[
1053
+ "confidenceThreshold"
1054
+ ] = f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
1055
+ model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
1056
+ model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
1057
+ LOGGER.info(f"{prefix} pipeline success")
1002
1058
  return model
1003
1059
 
1004
1060
  def add_callback(self, event: str, callback):