ultralytics 8.3.222__py3-none-any.whl → 8.3.223__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
- __version__ = "8.3.222"
3
+ __version__ = "8.3.223"
4
4
 
5
5
  import importlib
6
6
  import os
@@ -342,7 +342,7 @@ names:
342
342
  322: ringlet
343
343
  323: monarch butterfly
344
344
  324: small white
345
- 325: sulphur butterfly
345
+ 325: sulfur butterfly
346
346
  326: gossamer-winged butterfly
347
347
  327: starfish
348
348
  328: sea urchin
@@ -35,7 +35,7 @@ names:
35
35
  17: armband
36
36
  18: armchair
37
37
  19: armoire
38
- 20: armor/armour
38
+ 20: armor
39
39
  21: artichoke
40
40
  22: trash can/garbage can/wastebin/dustbin/trash barrel/trash bin
41
41
  23: ashtray
@@ -245,7 +245,7 @@ names:
245
245
  227: CD player
246
246
  228: celery
247
247
  229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone
248
- 230: chain mail/ring mail/chain armor/chain armour/ring armor/ring armour
248
+ 230: chain mail/ring mail/chain armor/ring armor
249
249
  231: chair
250
250
  232: chaise longue/chaise/daybed
251
251
  233: chalice
@@ -305,7 +305,7 @@ names:
305
305
  287: coin
306
306
  288: colander/cullender
307
307
  289: coleslaw/slaw
308
- 290: coloring material/colouring material
308
+ 290: coloring material
309
309
  291: combination lock
310
310
  292: pacifier/teething ring
311
311
  293: comic book
@@ -401,7 +401,7 @@ names:
401
401
  383: domestic ass/donkey
402
402
  384: doorknob/doorhandle
403
403
  385: doormat/welcome mat
404
- 386: doughnut/donut
404
+ 386: donut
405
405
  387: dove
406
406
  388: dragonfly
407
407
  389: drawer
@@ -1072,7 +1072,7 @@ names:
1072
1072
  1054: tag
1073
1073
  1055: taillight/rear light
1074
1074
  1056: tambourine
1075
- 1057: army tank/armored combat vehicle/armoured combat vehicle
1075
+ 1057: army tank/armored combat vehicle
1076
1076
  1058: tank/tank storage vessel/storage tank
1077
1077
  1059: tank top/tank top clothing
1078
1078
  1060: tape/tape sticky cloth or paper
@@ -182,7 +182,7 @@ names:
182
182
  163: Dolphin
183
183
  164: Door
184
184
  165: Door handle
185
- 166: Doughnut
185
+ 166: Donut
186
186
  167: Dragonfly
187
187
  168: Drawer
188
188
  169: Dress
@@ -107,9 +107,17 @@ from ultralytics.utils.checks import (
107
107
  is_intel,
108
108
  is_sudo_available,
109
109
  )
110
- from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
111
- from ultralytics.utils.export import onnx2engine, torch2imx, torch2onnx
112
- from ultralytics.utils.files import file_size, spaces_in_path
110
+ from ultralytics.utils.downloads import get_github_assets, safe_download
111
+ from ultralytics.utils.export import (
112
+ keras2pb,
113
+ onnx2engine,
114
+ onnx2saved_model,
115
+ pb2tfjs,
116
+ tflite2edgetpu,
117
+ torch2imx,
118
+ torch2onnx,
119
+ )
120
+ from ultralytics.utils.files import file_size
113
121
  from ultralytics.utils.metrics import batch_probiou
114
122
  from ultralytics.utils.nms import TorchNMS
115
123
  from ultralytics.utils.ops import Profile
@@ -206,15 +214,6 @@ def validate_args(format, passed_args, valid_args):
206
214
  assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'"
207
215
 
208
216
 
209
- def gd_outputs(gd):
210
- """Return TensorFlow GraphDef model output node names."""
211
- name_list, input_list = [], []
212
- for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
213
- name_list.append(node.name)
214
- input_list.extend(node.input)
215
- return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
216
-
217
-
218
217
  def try_export(inner_func):
219
218
  """YOLO export decorator, i.e. @try_export."""
220
219
  inner_args = get_default_args(inner_func)
@@ -371,7 +370,7 @@ class Exporter:
371
370
  LOGGER.warning("IMX export requires nms=True, setting nms=True.")
372
371
  self.args.nms = True
373
372
  if model.task not in {"detect", "pose", "classify"}:
374
- raise ValueError("IMX export only supported for detection and pose estimation models.")
373
+ raise ValueError("IMX export only supported for detection, pose estimation, and classification models.")
375
374
  if not hasattr(model, "names"):
376
375
  model.names = default_class_names()
377
376
  model.names = check_class_names(model.names)
@@ -461,6 +460,10 @@ class Exporter:
461
460
  from ultralytics.utils.export.imx import FXModel
462
461
 
463
462
  model = FXModel(model, self.imgsz)
463
+ if tflite or edgetpu:
464
+ from ultralytics.utils.export.tensorflow import tf_wrapper
465
+
466
+ model = tf_wrapper(model)
464
467
  for m in model.modules():
465
468
  if isinstance(m, Classify):
466
469
  m.export = True
@@ -642,7 +645,7 @@ class Exporter:
642
645
  assert TORCH_1_13, f"'nms=True' ONNX export requires torch>=1.13 (found torch=={TORCH_VERSION})"
643
646
 
644
647
  f = str(self.file.with_suffix(".onnx"))
645
- output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
648
+ output_names = ["output0", "output1"] if self.model.task == "segment" else ["output0"]
646
649
  dynamic = self.args.dynamic
647
650
  if dynamic:
648
651
  dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640)
@@ -1053,75 +1056,43 @@ class Exporter:
1053
1056
  if f.is_dir():
1054
1057
  shutil.rmtree(f) # delete output folder
1055
1058
 
1056
- # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
1057
- onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
1058
- if not onnx2tf_file.exists():
1059
- attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
1059
+ # Export to TF
1060
+ images = None
1061
+ if self.args.int8 and self.args.data:
1062
+ images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
1063
+ images = (
1064
+ torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz)
1065
+ .permute(0, 2, 3, 1)
1066
+ .numpy()
1067
+ .astype(np.float32)
1068
+ )
1060
1069
 
1061
1070
  # Export to ONNX
1062
1071
  if isinstance(self.model.model[-1], RTDETRDecoder):
1063
1072
  self.args.opset = self.args.opset or 19
1064
1073
  assert 16 <= self.args.opset <= 19, "RTDETR export requires opset>=16;<=19"
1065
1074
  self.args.simplify = True
1066
- f_onnx = self.export_onnx()
1067
-
1068
- # Export to TF
1069
- np_data = None
1070
- if self.args.int8:
1071
- tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
1072
- if self.args.data:
1073
- f.mkdir()
1074
- images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
1075
- images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
1076
- 0, 2, 3, 1
1077
- )
1078
- np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC
1079
- np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
1080
-
1081
- import onnx2tf # scoped for after ONNX export for reduced conflict during import
1082
-
1083
- LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
1084
- keras_model = onnx2tf.convert(
1085
- input_onnx_file_path=f_onnx,
1086
- output_folder_path=str(f),
1087
- not_use_onnxsim=True,
1088
- verbosity="error", # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
1089
- output_integer_quantized_tflite=self.args.int8,
1090
- custom_input_op_name_np_data_path=np_data,
1091
- enable_batchmatmul_unfold=True and not self.args.int8, # fix lower no. of detected objects on GPU delegate
1092
- output_signaturedefs=True, # fix error with Attention block group convolution
1093
- disable_group_convolution=self.args.format in {"tfjs", "edgetpu"}, # fix error with group convolution
1075
+ f_onnx = self.export_onnx() # ensure ONNX is available
1076
+ keras_model = onnx2saved_model(
1077
+ f_onnx,
1078
+ f,
1079
+ int8=self.args.int8,
1080
+ images=images,
1081
+ disable_group_convolution=self.args.format in {"tfjs", "edgetpu"},
1082
+ prefix=prefix,
1094
1083
  )
1095
1084
  YAML.save(f / "metadata.yaml", self.metadata) # add metadata.yaml
1096
-
1097
- # Remove/rename TFLite models
1098
- if self.args.int8:
1099
- tmp_file.unlink(missing_ok=True)
1100
- for file in f.rglob("*_dynamic_range_quant.tflite"):
1101
- file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
1102
- for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
1103
- file.unlink() # delete extra fp16 activation TFLite files
1104
-
1105
1085
  # Add TFLite metadata
1106
1086
  for file in f.rglob("*.tflite"):
1107
- f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file)
1087
+ file.unlink() if "quant_with_int16_act.tflite" in str(file) else self._add_tflite_metadata(file)
1108
1088
 
1109
1089
  return str(f), keras_model # or keras_model = tf.saved_model.load(f, tags=None, options=None)
1110
1090
 
1111
1091
  @try_export
1112
1092
  def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
1113
1093
  """Export YOLO model to TensorFlow GraphDef *.pb format https://github.com/leimao/Frozen-Graph-TensorFlow."""
1114
- import tensorflow as tf
1115
- from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
1116
-
1117
- LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
1118
1094
  f = self.file.with_suffix(".pb")
1119
-
1120
- m = tf.function(lambda x: keras_model(x)) # full model
1121
- m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
1122
- frozen_func = convert_variables_to_constants_v2(m)
1123
- frozen_func.graph.as_graph_def()
1124
- tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
1095
+ keras2pb(keras_model, f, prefix)
1125
1096
  return f
1126
1097
 
1127
1098
  @try_export
@@ -1189,22 +1160,11 @@ class Exporter:
1189
1160
  "sudo apt-get install edgetpu-compiler",
1190
1161
  ):
1191
1162
  subprocess.run(c if is_sudo_available() else c.replace("sudo ", ""), shell=True, check=True)
1192
- ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]
1193
1163
 
1164
+ ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]
1194
1165
  LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
1166
+ tflite2edgetpu(tflite_file=tflite_model, output_dir=tflite_model.parent, prefix=prefix)
1195
1167
  f = str(tflite_model).replace(".tflite", "_edgetpu.tflite") # Edge TPU model
1196
-
1197
- cmd = (
1198
- "edgetpu_compiler "
1199
- f'--out_dir "{Path(f).parent}" '
1200
- "--show_operations "
1201
- "--search_delegate "
1202
- "--delegate_search_step 30 "
1203
- "--timeout_sec 180 "
1204
- f'"{tflite_model}"'
1205
- )
1206
- LOGGER.info(f"{prefix} running '{cmd}'")
1207
- subprocess.run(cmd, shell=True)
1208
1168
  self._add_tflite_metadata(f)
1209
1169
  return f
1210
1170
 
@@ -1212,31 +1172,10 @@ class Exporter:
1212
1172
  def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
1213
1173
  """Export YOLO model to TensorFlow.js format."""
1214
1174
  check_requirements("tensorflowjs")
1215
- import tensorflow as tf
1216
- import tensorflowjs as tfjs
1217
1175
 
1218
- LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
1219
1176
  f = str(self.file).replace(self.file.suffix, "_web_model") # js dir
1220
1177
  f_pb = str(self.file.with_suffix(".pb")) # *.pb path
1221
-
1222
- gd = tf.Graph().as_graph_def() # TF GraphDef
1223
- with open(f_pb, "rb") as file:
1224
- gd.ParseFromString(file.read())
1225
- outputs = ",".join(gd_outputs(gd))
1226
- LOGGER.info(f"\n{prefix} output node names: {outputs}")
1227
-
1228
- quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
1229
- with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path
1230
- cmd = (
1231
- "tensorflowjs_converter "
1232
- f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
1233
- )
1234
- LOGGER.info(f"{prefix} running '{cmd}'")
1235
- subprocess.run(cmd, shell=True)
1236
-
1237
- if " " in f:
1238
- LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{f}'.")
1239
-
1178
+ pb2tfjs(pb_file=f_pb, output_dir=f, half=self.args.half, int8=self.args.int8, prefix=prefix)
1240
1179
  # Add metadata
1241
1180
  YAML.save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
1242
1181
  return f
@@ -89,7 +89,7 @@ class RTDETRDataset(YOLODataset):
89
89
  transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
90
90
  else:
91
91
  # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
92
- transforms = Compose([])
92
+ transforms = Compose([lambda x: {**x, **{"ratio_pad": [x["ratio_pad"], [0, 0]]}}])
93
93
  transforms.append(
94
94
  Format(
95
95
  bbox_format="xywh",
@@ -428,7 +428,7 @@ class AutoBackend(nn.Module):
428
428
  LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
429
429
  import tensorflow as tf
430
430
 
431
- from ultralytics.engine.exporter import gd_outputs
431
+ from ultralytics.utils.export.tensorflow import gd_outputs
432
432
 
433
433
  def wrap_frozen_graph(gd, inputs, outputs):
434
434
  """Wrap frozen graphs for deployment."""
@@ -166,22 +166,8 @@ class Detect(nn.Module):
166
166
  self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
167
167
  self.shape = shape
168
168
 
169
- if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
170
- box = x_cat[:, : self.reg_max * 4]
171
- cls = x_cat[:, self.reg_max * 4 :]
172
- else:
173
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
174
-
175
- if self.export and self.format in {"tflite", "edgetpu"}:
176
- # Precompute normalization factor to increase numerical stability
177
- # See https://github.com/ultralytics/ultralytics/issues/7371
178
- grid_h = shape[2]
179
- grid_w = shape[3]
180
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
181
- norm = self.strides / (self.stride[0] * grid_size)
182
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
183
- else:
184
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
169
+ box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
170
+ dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
185
171
  return torch.cat((dbox, cls.sigmoid()), 1)
186
172
 
187
173
  def bias_init(self):
@@ -391,20 +377,9 @@ class Pose(Detect):
391
377
  """Decode keypoints from predictions."""
392
378
  ndim = self.kpt_shape[1]
393
379
  if self.export:
394
- if self.format in {
395
- "tflite",
396
- "edgetpu",
397
- }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
398
- # Precompute normalization factor to increase numerical stability
399
- y = kpts.view(bs, *self.kpt_shape, -1)
400
- grid_h, grid_w = self.shape[2], self.shape[3]
401
- grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
402
- norm = self.strides / (self.stride[0] * grid_size)
403
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
404
- else:
405
- # NCNN fix
406
- y = kpts.view(bs, *self.kpt_shape, -1)
407
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
380
+ # NCNN fix
381
+ y = kpts.view(bs, *self.kpt_shape, -1)
382
+ a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
408
383
  if ndim == 3:
409
384
  a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
410
385
  return a.view(bs, self.nk, -1)
@@ -1,242 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
- from __future__ import annotations
3
+ from .engine import onnx2engine, torch2onnx
4
+ from .imx import torch2imx
5
+ from .tensorflow import keras2pb, onnx2saved_model, pb2tfjs, tflite2edgetpu
4
6
 
5
- import json
6
- from pathlib import Path
7
-
8
- import torch
9
-
10
- from ultralytics.utils import IS_JETSON, LOGGER
11
- from ultralytics.utils.torch_utils import TORCH_2_4
12
-
13
- from .imx import torch2imx # noqa
14
-
15
-
16
- def torch2onnx(
17
- torch_model: torch.nn.Module,
18
- im: torch.Tensor,
19
- onnx_file: str,
20
- opset: int = 14,
21
- input_names: list[str] = ["images"],
22
- output_names: list[str] = ["output0"],
23
- dynamic: bool | dict = False,
24
- ) -> None:
25
- """
26
- Export a PyTorch model to ONNX format.
27
-
28
- Args:
29
- torch_model (torch.nn.Module): The PyTorch model to export.
30
- im (torch.Tensor): Example input tensor for the model.
31
- onnx_file (str): Path to save the exported ONNX file.
32
- opset (int): ONNX opset version to use for export.
33
- input_names (list[str]): List of input tensor names.
34
- output_names (list[str]): List of output tensor names.
35
- dynamic (bool | dict, optional): Whether to enable dynamic axes.
36
-
37
- Notes:
38
- Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
39
- """
40
- kwargs = {"dynamo": False} if TORCH_2_4 else {}
41
- torch.onnx.export(
42
- torch_model,
43
- im,
44
- onnx_file,
45
- verbose=False,
46
- opset_version=opset,
47
- do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
48
- input_names=input_names,
49
- output_names=output_names,
50
- dynamic_axes=dynamic or None,
51
- **kwargs,
52
- )
53
-
54
-
55
- def onnx2engine(
56
- onnx_file: str,
57
- engine_file: str | None = None,
58
- workspace: int | None = None,
59
- half: bool = False,
60
- int8: bool = False,
61
- dynamic: bool = False,
62
- shape: tuple[int, int, int, int] = (1, 3, 640, 640),
63
- dla: int | None = None,
64
- dataset=None,
65
- metadata: dict | None = None,
66
- verbose: bool = False,
67
- prefix: str = "",
68
- ) -> None:
69
- """
70
- Export a YOLO model to TensorRT engine format.
71
-
72
- Args:
73
- onnx_file (str): Path to the ONNX file to be converted.
74
- engine_file (str, optional): Path to save the generated TensorRT engine file.
75
- workspace (int, optional): Workspace size in GB for TensorRT.
76
- half (bool, optional): Enable FP16 precision.
77
- int8 (bool, optional): Enable INT8 precision.
78
- dynamic (bool, optional): Enable dynamic input shapes.
79
- shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
80
- dla (int, optional): DLA core to use (Jetson devices only).
81
- dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
82
- metadata (dict, optional): Metadata to include in the engine file.
83
- verbose (bool, optional): Enable verbose logging.
84
- prefix (str, optional): Prefix for log messages.
85
-
86
- Raises:
87
- ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
88
- RuntimeError: If the ONNX file cannot be parsed.
89
-
90
- Notes:
91
- TensorRT version compatibility is handled for workspace size and engine building.
92
- INT8 calibration requires a dataset and generates a calibration cache.
93
- Metadata is serialized and written to the engine file if provided.
94
- """
95
- import tensorrt as trt
96
-
97
- engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
98
-
99
- logger = trt.Logger(trt.Logger.INFO)
100
- if verbose:
101
- logger.min_severity = trt.Logger.Severity.VERBOSE
102
-
103
- # Engine builder
104
- builder = trt.Builder(logger)
105
- config = builder.create_builder_config()
106
- workspace_bytes = int((workspace or 0) * (1 << 30))
107
- is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10 # is TensorRT >= 10
108
- if is_trt10 and workspace_bytes > 0:
109
- config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
110
- elif workspace_bytes > 0: # TensorRT versions 7, 8
111
- config.max_workspace_size = workspace_bytes
112
- flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
113
- network = builder.create_network(flag)
114
- half = builder.platform_has_fast_fp16 and half
115
- int8 = builder.platform_has_fast_int8 and int8
116
-
117
- # Optionally switch to DLA if enabled
118
- if dla is not None:
119
- if not IS_JETSON:
120
- raise ValueError("DLA is only available on NVIDIA Jetson devices")
121
- LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
122
- if not half and not int8:
123
- raise ValueError(
124
- "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
125
- )
126
- config.default_device_type = trt.DeviceType.DLA
127
- config.DLA_core = int(dla)
128
- config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
129
-
130
- # Read ONNX file
131
- parser = trt.OnnxParser(network, logger)
132
- if not parser.parse_from_file(onnx_file):
133
- raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
134
-
135
- # Network inputs
136
- inputs = [network.get_input(i) for i in range(network.num_inputs)]
137
- outputs = [network.get_output(i) for i in range(network.num_outputs)]
138
- for inp in inputs:
139
- LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
140
- for out in outputs:
141
- LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
142
-
143
- if dynamic:
144
- profile = builder.create_optimization_profile()
145
- min_shape = (1, shape[1], 32, 32) # minimum input shape
146
- max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:])) # max input shape
147
- for inp in inputs:
148
- profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
149
- config.add_optimization_profile(profile)
150
- if int8:
151
- config.set_calibration_profile(profile)
152
-
153
- LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
154
- if int8:
155
- config.set_flag(trt.BuilderFlag.INT8)
156
- config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
157
-
158
- class EngineCalibrator(trt.IInt8Calibrator):
159
- """
160
- Custom INT8 calibrator for TensorRT engine optimization.
161
-
162
- This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
163
- using a dataset. It handles batch generation, caching, and calibration algorithm selection.
164
-
165
- Attributes:
166
- dataset: Dataset for calibration.
167
- data_iter: Iterator over the calibration dataset.
168
- algo (trt.CalibrationAlgoType): Calibration algorithm type.
169
- batch (int): Batch size for calibration.
170
- cache (Path): Path to save the calibration cache.
171
-
172
- Methods:
173
- get_algorithm: Get the calibration algorithm to use.
174
- get_batch_size: Get the batch size to use for calibration.
175
- get_batch: Get the next batch to use for calibration.
176
- read_calibration_cache: Use existing cache instead of calibrating again.
177
- write_calibration_cache: Write calibration cache to disk.
178
- """
179
-
180
- def __init__(
181
- self,
182
- dataset, # ultralytics.data.build.InfiniteDataLoader
183
- cache: str = "",
184
- ) -> None:
185
- """Initialize the INT8 calibrator with dataset and cache path."""
186
- trt.IInt8Calibrator.__init__(self)
187
- self.dataset = dataset
188
- self.data_iter = iter(dataset)
189
- self.algo = (
190
- trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2
191
- if dla is not None
192
- else trt.CalibrationAlgoType.MINMAX_CALIBRATION
193
- )
194
- self.batch = dataset.batch_size
195
- self.cache = Path(cache)
196
-
197
- def get_algorithm(self) -> trt.CalibrationAlgoType:
198
- """Get the calibration algorithm to use."""
199
- return self.algo
200
-
201
- def get_batch_size(self) -> int:
202
- """Get the batch size to use for calibration."""
203
- return self.batch or 1
204
-
205
- def get_batch(self, names) -> list[int] | None:
206
- """Get the next batch to use for calibration, as a list of device memory pointers."""
207
- try:
208
- im0s = next(self.data_iter)["img"] / 255.0
209
- im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
210
- return [int(im0s.data_ptr())]
211
- except StopIteration:
212
- # Return None to signal to TensorRT there is no calibration data remaining
213
- return None
214
-
215
- def read_calibration_cache(self) -> bytes | None:
216
- """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
217
- if self.cache.exists() and self.cache.suffix == ".cache":
218
- return self.cache.read_bytes()
219
-
220
- def write_calibration_cache(self, cache: bytes) -> None:
221
- """Write calibration cache to disk."""
222
- _ = self.cache.write_bytes(cache)
223
-
224
- # Load dataset w/ builder (for batching) and calibrate
225
- config.int8_calibrator = EngineCalibrator(
226
- dataset=dataset,
227
- cache=str(Path(onnx_file).with_suffix(".cache")),
228
- )
229
-
230
- elif half:
231
- config.set_flag(trt.BuilderFlag.FP16)
232
-
233
- # Write file
234
- build = builder.build_serialized_network if is_trt10 else builder.build_engine
235
- with build(network, config) as engine, open(engine_file, "wb") as t:
236
- # Metadata
237
- if metadata is not None:
238
- meta = json.dumps(metadata)
239
- t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
240
- t.write(meta.encode())
241
- # Model
242
- t.write(engine if is_trt10 else engine.serialize())
7
+ __all__ = ["keras2pb", "onnx2engine", "onnx2saved_model", "pb2tfjs", "tflite2edgetpu", "torch2imx", "torch2onnx"]
@@ -0,0 +1,240 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+ from ultralytics.utils import IS_JETSON, LOGGER
11
+ from ultralytics.utils.torch_utils import TORCH_2_4
12
+
13
+
14
+ def torch2onnx(
15
+ torch_model: torch.nn.Module,
16
+ im: torch.Tensor,
17
+ onnx_file: str,
18
+ opset: int = 14,
19
+ input_names: list[str] = ["images"],
20
+ output_names: list[str] = ["output0"],
21
+ dynamic: bool | dict = False,
22
+ ) -> None:
23
+ """
24
+ Export a PyTorch model to ONNX format.
25
+
26
+ Args:
27
+ torch_model (torch.nn.Module): The PyTorch model to export.
28
+ im (torch.Tensor): Example input tensor for the model.
29
+ onnx_file (str): Path to save the exported ONNX file.
30
+ opset (int): ONNX opset version to use for export.
31
+ input_names (list[str]): List of input tensor names.
32
+ output_names (list[str]): List of output tensor names.
33
+ dynamic (bool | dict, optional): Whether to enable dynamic axes.
34
+
35
+ Notes:
36
+ Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
37
+ """
38
+ kwargs = {"dynamo": False} if TORCH_2_4 else {}
39
+ torch.onnx.export(
40
+ torch_model,
41
+ im,
42
+ onnx_file,
43
+ verbose=False,
44
+ opset_version=opset,
45
+ do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
46
+ input_names=input_names,
47
+ output_names=output_names,
48
+ dynamic_axes=dynamic or None,
49
+ **kwargs,
50
+ )
51
+
52
+
53
+ def onnx2engine(
54
+ onnx_file: str,
55
+ engine_file: str | None = None,
56
+ workspace: int | None = None,
57
+ half: bool = False,
58
+ int8: bool = False,
59
+ dynamic: bool = False,
60
+ shape: tuple[int, int, int, int] = (1, 3, 640, 640),
61
+ dla: int | None = None,
62
+ dataset=None,
63
+ metadata: dict | None = None,
64
+ verbose: bool = False,
65
+ prefix: str = "",
66
+ ) -> None:
67
+ """
68
+ Export a YOLO model to TensorRT engine format.
69
+
70
+ Args:
71
+ onnx_file (str): Path to the ONNX file to be converted.
72
+ engine_file (str, optional): Path to save the generated TensorRT engine file.
73
+ workspace (int, optional): Workspace size in GB for TensorRT.
74
+ half (bool, optional): Enable FP16 precision.
75
+ int8 (bool, optional): Enable INT8 precision.
76
+ dynamic (bool, optional): Enable dynamic input shapes.
77
+ shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
78
+ dla (int, optional): DLA core to use (Jetson devices only).
79
+ dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
80
+ metadata (dict, optional): Metadata to include in the engine file.
81
+ verbose (bool, optional): Enable verbose logging.
82
+ prefix (str, optional): Prefix for log messages.
83
+
84
+ Raises:
85
+ ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
86
+ RuntimeError: If the ONNX file cannot be parsed.
87
+
88
+ Notes:
89
+ TensorRT version compatibility is handled for workspace size and engine building.
90
+ INT8 calibration requires a dataset and generates a calibration cache.
91
+ Metadata is serialized and written to the engine file if provided.
92
+ """
93
+ import tensorrt as trt
94
+
95
+ engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
96
+
97
+ logger = trt.Logger(trt.Logger.INFO)
98
+ if verbose:
99
+ logger.min_severity = trt.Logger.Severity.VERBOSE
100
+
101
+ # Engine builder
102
+ builder = trt.Builder(logger)
103
+ config = builder.create_builder_config()
104
+ workspace_bytes = int((workspace or 0) * (1 << 30))
105
+ is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10 # is TensorRT >= 10
106
+ if is_trt10 and workspace_bytes > 0:
107
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
108
+ elif workspace_bytes > 0: # TensorRT versions 7, 8
109
+ config.max_workspace_size = workspace_bytes
110
+ flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
111
+ network = builder.create_network(flag)
112
+ half = builder.platform_has_fast_fp16 and half
113
+ int8 = builder.platform_has_fast_int8 and int8
114
+
115
+ # Optionally switch to DLA if enabled
116
+ if dla is not None:
117
+ if not IS_JETSON:
118
+ raise ValueError("DLA is only available on NVIDIA Jetson devices")
119
+ LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
120
+ if not half and not int8:
121
+ raise ValueError(
122
+ "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
123
+ )
124
+ config.default_device_type = trt.DeviceType.DLA
125
+ config.DLA_core = int(dla)
126
+ config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
127
+
128
+ # Read ONNX file
129
+ parser = trt.OnnxParser(network, logger)
130
+ if not parser.parse_from_file(onnx_file):
131
+ raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
132
+
133
+ # Network inputs
134
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
135
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
136
+ for inp in inputs:
137
+ LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
138
+ for out in outputs:
139
+ LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
140
+
141
+ if dynamic:
142
+ profile = builder.create_optimization_profile()
143
+ min_shape = (1, shape[1], 32, 32) # minimum input shape
144
+ max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:])) # max input shape
145
+ for inp in inputs:
146
+ profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
147
+ config.add_optimization_profile(profile)
148
+ if int8:
149
+ config.set_calibration_profile(profile)
150
+
151
+ LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
152
+ if int8:
153
+ config.set_flag(trt.BuilderFlag.INT8)
154
+ config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
155
+
156
+ class EngineCalibrator(trt.IInt8Calibrator):
157
+ """
158
+ Custom INT8 calibrator for TensorRT engine optimization.
159
+
160
+ This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
161
+ using a dataset. It handles batch generation, caching, and calibration algorithm selection.
162
+
163
+ Attributes:
164
+ dataset: Dataset for calibration.
165
+ data_iter: Iterator over the calibration dataset.
166
+ algo (trt.CalibrationAlgoType): Calibration algorithm type.
167
+ batch (int): Batch size for calibration.
168
+ cache (Path): Path to save the calibration cache.
169
+
170
+ Methods:
171
+ get_algorithm: Get the calibration algorithm to use.
172
+ get_batch_size: Get the batch size to use for calibration.
173
+ get_batch: Get the next batch to use for calibration.
174
+ read_calibration_cache: Use existing cache instead of calibrating again.
175
+ write_calibration_cache: Write calibration cache to disk.
176
+ """
177
+
178
+ def __init__(
179
+ self,
180
+ dataset, # ultralytics.data.build.InfiniteDataLoader
181
+ cache: str = "",
182
+ ) -> None:
183
+ """Initialize the INT8 calibrator with dataset and cache path."""
184
+ trt.IInt8Calibrator.__init__(self)
185
+ self.dataset = dataset
186
+ self.data_iter = iter(dataset)
187
+ self.algo = (
188
+ trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2
189
+ if dla is not None
190
+ else trt.CalibrationAlgoType.MINMAX_CALIBRATION
191
+ )
192
+ self.batch = dataset.batch_size
193
+ self.cache = Path(cache)
194
+
195
+ def get_algorithm(self) -> trt.CalibrationAlgoType:
196
+ """Get the calibration algorithm to use."""
197
+ return self.algo
198
+
199
+ def get_batch_size(self) -> int:
200
+ """Get the batch size to use for calibration."""
201
+ return self.batch or 1
202
+
203
+ def get_batch(self, names) -> list[int] | None:
204
+ """Get the next batch to use for calibration, as a list of device memory pointers."""
205
+ try:
206
+ im0s = next(self.data_iter)["img"] / 255.0
207
+ im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
208
+ return [int(im0s.data_ptr())]
209
+ except StopIteration:
210
+ # Return None to signal to TensorRT there is no calibration data remaining
211
+ return None
212
+
213
+ def read_calibration_cache(self) -> bytes | None:
214
+ """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
215
+ if self.cache.exists() and self.cache.suffix == ".cache":
216
+ return self.cache.read_bytes()
217
+
218
+ def write_calibration_cache(self, cache: bytes) -> None:
219
+ """Write calibration cache to disk."""
220
+ _ = self.cache.write_bytes(cache)
221
+
222
+ # Load dataset w/ builder (for batching) and calibrate
223
+ config.int8_calibrator = EngineCalibrator(
224
+ dataset=dataset,
225
+ cache=str(Path(onnx_file).with_suffix(".cache")),
226
+ )
227
+
228
+ elif half:
229
+ config.set_flag(trt.BuilderFlag.FP16)
230
+
231
+ # Write file
232
+ build = builder.build_serialized_network if is_trt10 else builder.build_engine
233
+ with build(network, config) as engine, open(engine_file, "wb") as t:
234
+ # Metadata
235
+ if metadata is not None:
236
+ meta = json.dumps(metadata)
237
+ t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
238
+ t.write(meta.encode())
239
+ # Model
240
+ t.write(engine if is_trt10 else engine.serialize())
@@ -0,0 +1,221 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from ultralytics.nn.modules import Detect, Pose
11
+ from ultralytics.utils import LOGGER
12
+ from ultralytics.utils.downloads import attempt_download_asset
13
+ from ultralytics.utils.files import spaces_in_path
14
+ from ultralytics.utils.tal import make_anchors
15
+
16
+
17
def tf_wrapper(model: torch.nn.Module) -> torch.nn.Module:
    """
    Patch Detect (and Pose) heads of `model` in place with TensorFlow-friendly decode methods.

    Every submodule that is an instance of `Detect` (subclasses included) gets `_inference` rebound to
    `_tf_inference`. Submodules whose exact type is `Pose` additionally get `kpts_decode` rebound to `tf_kpts_decode`.

    Args:
        model (torch.nn.Module): Model whose submodules are patched in place.

    Returns:
        (torch.nn.Module): The same `model` object that was passed in.
    """
    from types import MethodType

    for head in (m for m in model.modules() if isinstance(m, Detect)):
        head._inference = MethodType(_tf_inference, head)
        if type(head) is Pose:  # exact type match on purpose: Pose subclasses keep their own kpts_decode
            head.kpts_decode = MethodType(tf_kpts_decode, head)
    return model
28
+
29
+
30
+ def _tf_inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
31
+ """Decode boxes and cls scores for tf object detection."""
32
+ shape = x[0].shape # BCHW
33
+ x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
34
+ box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
35
+ if self.dynamic or self.shape != shape:
36
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
37
+ self.shape = shape
38
+ grid_h, grid_w = shape[2], shape[3]
39
+ grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
40
+ norm = self.strides / (self.stride[0] * grid_size)
41
+ dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
42
+ return torch.cat((dbox, cls.sigmoid()), 1)
43
+
44
+
45
def tf_kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:
    """
    Decode raw keypoint predictions into normalized coordinates for TensorFlow export.

    Rebound onto `Pose` heads by `tf_wrapper`. Reads `self.shape`, `self.anchors` and `self.strides`, which
    `_tf_inference` caches on the head.

    Args:
        bs (int): Batch size.
        kpts (torch.Tensor): Raw keypoint predictions, viewed here as (bs, *self.kpt_shape, -1).

    Returns:
        (torch.Tensor): Decoded keypoints viewed as (bs, self.nk, -1). When `self.kpt_shape[1] == 3` the third value
            per keypoint is passed through a sigmoid.
    """
    # Normalize by the grid size times the first stride instead of multiplying by stride alone.
    # This is required for TFLite export to avoid the 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug,
    # and precomputing the factor improves numerical stability.
    preds = kpts.view(bs, *self.kpt_shape, -1)
    size = torch.tensor([self.shape[3], self.shape[2]], device=preds.device).reshape(1, 2, 1)
    scale = self.strides / (self.stride[0] * size)
    xy = preds[:, :, :2] * 2.0 + (self.anchors - 0.5)
    decoded = xy * scale
    if self.kpt_shape[1] == 3:  # presumably a visibility/confidence logit per keypoint
        decoded = torch.cat((decoded, preds[:, :, 2:3].sigmoid()), 2)
    return decoded.view(bs, self.nk, -1)
58
+
59
+
60
def onnx2saved_model(
    onnx_file: str,
    output_dir: Path,
    int8: bool = False,
    images: np.ndarray = None,
    disable_group_convolution: bool = False,
    prefix="",
):
    """
    Convert an ONNX model to TensorFlow SavedModel format (plus TFLite files) using onnx2tf.

    Args:
        onnx_file (str): ONNX file path.
        output_dir (Path | str): Output directory path for the SavedModel. Created (with parents) if it is missing.
        int8 (bool, optional): Enable INT8 quantization. Defaults to False.
        images (np.ndarray, optional): Calibration images for INT8 quantization in BHWC format.
        disable_group_convolution (bool, optional): Disable group convolution optimization. Defaults to False.
        prefix (str, optional): Logging prefix. Defaults to "".

    Returns:
        (keras.Model): Keras model returned by `onnx2tf.convert`.

    Note:
        Requires the onnx2tf package. The onnx2tf sample calibration file is always downloaded if it is not present in
        the current working directory, regardless of `int8`. When `int8=True`:
        - the temporary calibration file is removed after conversion;
        - `*_dynamic_range_quant.tflite` files are renamed to `*_int8.tflite`;
        - `*_integer_quant_with_int16_act.tflite` files are deleted.
    """
    output_dir = Path(output_dir)  # accept str paths as well as Path
    # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
    onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
    if not onnx2tf_file.exists():
        attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)

    np_data = None
    tmp_file = output_dir / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
    if int8 and images is not None:
        output_dir.mkdir(parents=True, exist_ok=True)  # don't fail if the directory already exists
        np.save(str(tmp_file), images)  # BHWC
        # onnx2tf custom input format: [input_name, npy_path, mean, std]
        np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]

    import onnx2tf  # scoped for after ONNX export for reduced conflict during import

    LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
    keras_model = onnx2tf.convert(
        input_onnx_file_path=onnx_file,
        output_folder_path=str(output_dir),
        not_use_onnxsim=True,
        verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
        output_integer_quantized_tflite=int8,
        custom_input_op_name_np_data_path=np_data,
        enable_batchmatmul_unfold=not int8,  # fix lower no. of detected objects on GPU delegate
        output_signaturedefs=True,  # fix error with Attention block group convolution
        disable_group_convolution=disable_group_convolution,  # fix error with group convolution
    )

    # Remove/rename TFLite models
    if int8:
        tmp_file.unlink(missing_ok=True)
        for file in output_dir.rglob("*_dynamic_range_quant.tflite"):
            file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
        for file in output_dir.rglob("*_integer_quant_with_int16_act.tflite"):
            file.unlink()  # delete extra fp16 activation TFLite files
    return keras_model
121
+
122
+
123
def keras2pb(keras_model, file: Path, prefix=""):
    """
    Convert a Keras model to a frozen TensorFlow GraphDef (.pb) file.

    Args:
        keras_model (tf_keras.Model): Keras model to freeze. The shape and dtype of its first input define the traced
            signature.
        file (Path | str): Output file path, used exactly as given. The suffix is not changed, so pass a `.pb` path.
        prefix (str, optional): Logging prefix. Defaults to "".

    Note:
        Variables are converted to constants so the written graph is self-contained for inference.
    """
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    file = Path(file)  # accept str paths as well as Path
    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
    concrete = tf.function(lambda x: keras_model(x)).get_concrete_function(spec)  # trace the full model
    frozen_func = convert_variables_to_constants_v2(concrete)
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(file.parent), name=file.name, as_text=False)
144
+
145
+
146
def tflite2edgetpu(tflite_file: str | Path, output_dir: str | Path, prefix: str = ""):
    """
    Convert a TensorFlow Lite model to Edge TPU format using the Edge TPU compiler.

    Args:
        tflite_file (str | Path): Path to the input TensorFlow Lite (.tflite) model file.
        output_dir (str | Path): Output directory path for the compiled Edge TPU model.
        prefix (str, optional): Logging prefix. Defaults to "".

    Note:
        Requires the `edgetpu_compiler` executable on PATH. The compiler is invoked with an argument list rather than
        through a shell, so paths containing spaces, quotes or shell metacharacters are passed verbatim. A missing
        compiler or a non-zero exit code is logged as a warning rather than raised.
    """
    import shlex
    import subprocess

    cmd = [
        "edgetpu_compiler",
        "--out_dir",
        str(output_dir),
        "--show_operations",
        "--search_delegate",
        "--delegate_search_step",
        "30",
        "--timeout_sec",
        "180",
        str(tflite_file),
    ]
    LOGGER.info(f"{prefix} running '{shlex.join(cmd)}'")
    try:
        result = subprocess.run(cmd)
    except FileNotFoundError:
        LOGGER.warning(f"{prefix} 'edgetpu_compiler' not found on PATH, Edge TPU compilation skipped.")
        return
    if result.returncode != 0:
        LOGGER.warning(f"{prefix} 'edgetpu_compiler' failed with exit code {result.returncode}.")
172
+
173
+
174
def pb2tfjs(pb_file: str, output_dir: str, half: bool = False, int8: bool = False, prefix: str = ""):
    """
    Convert a TensorFlow GraphDef (.pb) model to TensorFlow.js format.

    Args:
        pb_file (str | Path): Path to the input TensorFlow GraphDef (.pb) model file.
        output_dir (str | Path): Output directory path for the converted TensorFlow.js model.
        half (bool, optional): Enable FP16 quantization. Takes precedence over `int8`. Defaults to False.
        int8 (bool, optional): Enable INT8 (uint8) quantization. Defaults to False.
        prefix (str, optional): Logging prefix. Defaults to "".

    Note:
        Requires the tensorflowjs package. Uses the tensorflowjs_converter command-line tool, invoked with an argument
        list rather than through a shell. Spaces in paths are handled via `spaces_in_path`, and a warning is logged if
        the output directory contains spaces. A missing converter or a non-zero exit code is logged as a warning rather
        than raised.
    """
    import shlex
    import subprocess

    import tensorflow as tf
    import tensorflowjs as tfjs

    LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")

    gd = tf.Graph().as_graph_def()  # TF GraphDef
    with open(pb_file, "rb") as file:
        gd.ParseFromString(file.read())
    outputs = ",".join(gd_outputs(gd))
    LOGGER.info(f"\n{prefix} output node names: {outputs}")

    quantization = ["--quantize_float16"] if half else ["--quantize_uint8"] if int8 else []
    with spaces_in_path(pb_file) as fpb_, spaces_in_path(output_dir) as f_:  # exporter can not handle spaces in path
        cmd = [
            "tensorflowjs_converter",
            "--input_format=tf_frozen_model",
            *quantization,
            f"--output_node_names={outputs}",
            str(fpb_),
            str(f_),
        ]
        LOGGER.info(f"{prefix} running '{shlex.join(cmd)}'")
        try:
            result = subprocess.run(cmd)
        except FileNotFoundError:
            LOGGER.warning(f"{prefix} 'tensorflowjs_converter' not found on PATH, TensorFlow.js conversion skipped.")
        else:
            if result.returncode != 0:
                LOGGER.warning(f"{prefix} 'tensorflowjs_converter' failed with exit code {result.returncode}.")

    if " " in str(output_dir):  # str() so Path arguments don't raise TypeError
        LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{output_dir}'.")
213
+
214
+
215
def gd_outputs(gd):
    """
    Return the sorted output tensor names ("<node>:0") of a TensorFlow GraphDef.

    An output node is one whose name is never listed in any node's `input`. Nodes whose names start with "NoOp" are
    excluded. Input references are compared verbatim, so a node referenced only as e.g. "name:1" or "^name" is still
    reported as an output.
    """
    produced, consumed = set(), set()
    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
        produced.add(node.name)
        consumed.update(node.input)
    return sorted(f"{name}:0" for name in produced - consumed if not name.startswith("NoOp"))
ultralytics/utils/nms.py CHANGED
@@ -231,9 +231,11 @@ class TorchNMS:
231
231
  upper_mask = row_idx < col_idx
232
232
  ious = ious * upper_mask
233
233
  # Zeroing these scores ensures the additional indices would not affect the final results
234
- scores[~((ious >= iou_threshold).sum(0) <= 0)] = 0
234
+ scores_ = scores[sorted_idx]
235
+ scores_[~((ious >= iou_threshold).sum(0) <= 0)] = 0
236
+ scores[sorted_idx] = scores_ # update original tensor for NMSModel
235
237
  # NOTE: return indices with fixed length to avoid TFLite reshape error
236
- pick = torch.topk(scores, scores.shape[0]).indices
238
+ pick = torch.topk(scores_, scores_.shape[0]).indices
237
239
  return sorted_idx[pick]
238
240
 
239
241
  @staticmethod
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ultralytics
3
- Version: 8.3.222
3
+ Version: 8.3.223
4
4
  Summary: Ultralytics YOLO 🚀 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
5
5
  Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
6
6
  Maintainer-email: Ultralytics <hello@ultralytics.com>
@@ -44,7 +44,7 @@ Requires-Dist: torch!=2.4.0,>=1.8.0; sys_platform == "win32"
44
44
  Requires-Dist: torchvision>=0.9.0
45
45
  Requires-Dist: psutil
46
46
  Requires-Dist: polars
47
- Requires-Dist: ultralytics-thop>=2.0.0
47
+ Requires-Dist: ultralytics-thop>=2.0.18
48
48
  Provides-Extra: dev
49
49
  Requires-Dist: ipython; extra == "dev"
50
50
  Requires-Dist: pytest; extra == "dev"
@@ -7,7 +7,7 @@ tests/test_exports.py,sha256=OMLio2uUhyqo8D8qB5xUwmk7Po2rMeAACRc8WYoxbj4,13147
7
7
  tests/test_integrations.py,sha256=6QgSh9n0J04RdUYz08VeVOnKmf4S5MDEQ0chzS7jo_c,6220
8
8
  tests/test_python.py,sha256=OChceQcDDAy07yACnmOoGfimRo_4YdyiMwukGEgozXA,27735
9
9
  tests/test_solutions.py,sha256=j_PZZ5tMR1Y5ararY-OTXZr1hYJ7vEVr8H3w4O1tbQs,14153
10
- ultralytics/__init__.py,sha256=pcHeLhMCBEuP-sM307qncLjoxjr7uW1KqDHyi7RwD60,1302
10
+ ultralytics/__init__.py,sha256=IFuXT77f7jmVOvOHnLjLEIrgQ-RfhI6Rq7ykdDC42GI,1302
11
11
  ultralytics/py.typed,sha256=la67KBlbjXN-_-DfGNcdOcjYumVpKG_Tkw-8n5dnGB4,8
12
12
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
13
13
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
@@ -18,7 +18,7 @@ ultralytics/cfg/datasets/DOTAv1.5.yaml,sha256=VZ_KKFX0H2YvlFVJ8JHcLWYBZ2xiQ6Z-RO
18
18
  ultralytics/cfg/datasets/DOTAv1.yaml,sha256=JrDuYcQ0JU9lJlCA-dCkMNko_jaj6MAVGHjsfjeZ_u0,1181
19
19
  ultralytics/cfg/datasets/GlobalWheat2020.yaml,sha256=dnr_loeYSE6Eo_f7V1yubILsMRBMRm1ozyC5r7uT-iY,2144
20
20
  ultralytics/cfg/datasets/HomeObjects-3K.yaml,sha256=xEtSqEad-rtfGuIrERjjhdISggmPlvaX-315ZzKz50I,934
21
- ultralytics/cfg/datasets/ImageNet.yaml,sha256=GvDWypLVG_H3H67Ai8IC1pvK6fwcTtF5FRhzO1OXXDU,42530
21
+ ultralytics/cfg/datasets/ImageNet.yaml,sha256=N9NHhIgnlNIBqZZbzQZAW3aCnz6RSXQABnopaDs5BmE,42529
22
22
  ultralytics/cfg/datasets/Objects365.yaml,sha256=8Bl-NAm0mlMW8EfMsz39JZo-HCvmp0ejJXaMeoHTpqw,9649
23
23
  ultralytics/cfg/datasets/SKU-110K.yaml,sha256=xvRkq3SdDOwBA91U85bln7HTXkod5MvFX6pt1PxTjJE,2609
24
24
  ultralytics/cfg/datasets/VOC.yaml,sha256=84BaL-iwG03M_W9hNzjgEQi918dZgSHbCgf9DShjwLA,3747
@@ -41,9 +41,9 @@ ultralytics/cfg/datasets/dog-pose.yaml,sha256=BI-2S3_cSVyV2Gfzbs_3GzvivRlikT0ANj
41
41
  ultralytics/cfg/datasets/dota8-multispectral.yaml,sha256=2lMBi1Q3_pc0auK00yX80oF7oUMo0bUlwjkOrp33hvs,1216
42
42
  ultralytics/cfg/datasets/dota8.yaml,sha256=5n4h_4zdrtUSkmH5DHJ-JLPvfiATcieIkgP3NeOP5nI,1060
43
43
  ultralytics/cfg/datasets/hand-keypoints.yaml,sha256=NglEDsfNRe0DaYnwy7n6hYUxEAjV-V2NZBUbj1qJaag,1365
44
- ultralytics/cfg/datasets/lvis.yaml,sha256=lMvPfuiDv_o2qLxAWoh9WMrvjKJ5moLrcx1gr3RG_pM,29680
44
+ ultralytics/cfg/datasets/lvis.yaml,sha256=RescdwAJ8EU1o7Sm0YlxYsGbQFNU1p-LFbFKYEt5MhE,29596
45
45
  ultralytics/cfg/datasets/medical-pills.yaml,sha256=RK7iQFpDDkUS6EsEGqlbFjoohi3cgSsUIbsk7UItyds,792
46
- ultralytics/cfg/datasets/open-images-v7.yaml,sha256=wK9v3OAGdHORkFdqoBi0hS0fa1b74LLroAzUSWjxEqw,12119
46
+ ultralytics/cfg/datasets/open-images-v7.yaml,sha256=2fVFmb8UEYH-LkX0z5GlYp__U0_GDqVgVqzmnfFerm8,12116
47
47
  ultralytics/cfg/datasets/package-seg.yaml,sha256=V4uyTDWWzgft24y9HJWuELKuZ5AndAHXbanxMI6T8GU,849
48
48
  ultralytics/cfg/datasets/signature.yaml,sha256=gBvU3715gVxVAafI_yaYczGX3kfEfA4BttbiMkgOXNk,774
49
49
  ultralytics/cfg/datasets/tiger-pose.yaml,sha256=bJ7nBTDQwXRHtlg3xmo4C2bOpPn_r4l8-DezSWMYNcU,1196
@@ -121,7 +121,7 @@ ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J
121
121
  ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
122
122
  ultralytics/data/scripts/get_imagenet.sh,sha256=hr42H16bM47iT27rgS7MpEo-GeOZAYUQXgr0B2cwn48,1705
123
123
  ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
124
- ultralytics/engine/exporter.py,sha256=zFk58rZvZvaxK9dIwHOLjnqcoyRnQn0l5t2RbNIGsIE,73417
124
+ ultralytics/engine/exporter.py,sha256=89hggNbcH7zFAG8QJmShoHFZMvn0SpHF_yTEJ4CMbsc,69852
125
125
  ultralytics/engine/model.py,sha256=d7yGl8ybd7v8W4Q-ueSDAVfumDhsx0QCp4mx8OKf0Z8,53448
126
126
  ultralytics/engine/predictor.py,sha256=ZQrx1Bz4X8aTgGjrOSdRSP7SCtQ05uqz6IitEan_Gyk,22813
127
127
  ultralytics/engine/results.py,sha256=oHQdV_eIMvAU2qLCV7wG7iLifdfaLEgP80lDXB5ghkg,71490
@@ -147,7 +147,7 @@ ultralytics/models/rtdetr/__init__.py,sha256=F4NEQqtcVKFxj97Dh7rkn2Vu3JG4Ea_nxqr
147
147
  ultralytics/models/rtdetr/model.py,sha256=Pq9QDgaZetDnjxdYSoomj2s6vOGSdpsqVfyN5j0GUmc,2292
148
148
  ultralytics/models/rtdetr/predict.py,sha256=43-gGCHEH7UQQ6H1oXdlDlrM39esnp-YEhqCvZOwtOM,4279
149
149
  ultralytics/models/rtdetr/train.py,sha256=SNntxGHXatbNqn1yna5_dDQiR_ciDK6o_4S7JIHU7EY,3765
150
- ultralytics/models/rtdetr/val.py,sha256=l26CzpcYHYC0sQ--rKUFBCYl73nsgAGOj1U3xScNzFs,8918
150
+ ultralytics/models/rtdetr/val.py,sha256=UXaoNiy81zdkv6d79x1oGyR8T7dwuV5Y4m0Gpe-LQts,8976
151
151
  ultralytics/models/sam/__init__.py,sha256=p1BKLawQFvVxmdk7LomFVWX-67Kc-AP4PJBNPfU_Nuc,359
152
152
  ultralytics/models/sam/amg.py,sha256=nFq4EwHf65W2N5Ipo4W69nGRhCbJEh_boYQ8SIPWBZ0,11816
153
153
  ultralytics/models/sam/build.py,sha256=uKCgHpcYgV26OFuMq5RaGR8aXYoEtNoituT06bmnW44,12790
@@ -196,14 +196,14 @@ ultralytics/models/yolo/yoloe/train.py,sha256=qefvNNXDTOK1tO3va0kNHr8lE5QJkOlV8G
196
196
  ultralytics/models/yolo/yoloe/train_seg.py,sha256=aCV7M8oQOvODFnU4piZdJh3tIrBJYAzZfRVRx1vRgxo,4956
197
197
  ultralytics/models/yolo/yoloe/val.py,sha256=5Gd9EoFH0FmKKvWXBl4J7gBe9DVxIczN-s3ceHwdUDo,9458
198
198
  ultralytics/nn/__init__.py,sha256=538LZPUKKvc3JCMgiQ4VLGqRN2ZAaVLFcQbeNNHFkEA,545
199
- ultralytics/nn/autobackend.py,sha256=918iNweM3fTuRIbxHbXC1wspOj9rlGkuwalQ61uYLbk,42694
199
+ ultralytics/nn/autobackend.py,sha256=gw8REfburF36l9Hyh11eYzy7UnMvuX1Dm3cjsJBA1TM,42702
200
200
  ultralytics/nn/tasks.py,sha256=vRr6HTucM7Eg3kxzhYtyjgEAdacZ7gIDU3yPbMnyYMM,70834
201
201
  ultralytics/nn/text_model.py,sha256=pHqnKe8UueR1MuwJcIE_IvrnYIlt68QL796xjcRJs2A,15275
202
202
  ultralytics/nn/modules/__init__.py,sha256=5Sg_28MDfKwdu14Ty_WCaiIXZyjBSQ-xCNCwnoz_w-w,3198
203
203
  ultralytics/nn/modules/activation.py,sha256=75JcIMH2Cu9GTC2Uf55r_5YLpxcrXQDaVoeGQ0hlUAU,2233
204
204
  ultralytics/nn/modules/block.py,sha256=eQ8DegyvBG9k-O_QgSZe5XGmpravqwlnSCCBW6bHRXo,70622
205
205
  ultralytics/nn/modules/conv.py,sha256=MISNAK8NzAZhNUusVKWvTHQ8IsofwM-5X0gChCagsaY,21457
206
- ultralytics/nn/modules/head.py,sha256=HBSoHOcd9hikpBLF9BdgIItYJDNWIO8NiNYMcMw6ThM,53512
206
+ ultralytics/nn/modules/head.py,sha256=XBOLfpxgApIhNmdgnWoECep0wKhrw8LWtmd1TrWNBak,52076
207
207
  ultralytics/nn/modules/transformer.py,sha256=9aq0Yo9V3C4y_McSje4qE1d_PTWDctTsrb98MyXxigc,31470
208
208
  ultralytics/nn/modules/utils.py,sha256=9kLeEtvEBFLugz53TkdI4mifD-39a-upjPD-wrE8opU,6092
209
209
  ultralytics/solutions/__init__.py,sha256=Jj7OcRiYjHH-e104H4xTgjjR5W6aPB4mBRndbaSPmgU,1209
@@ -252,7 +252,7 @@ ultralytics/utils/instance.py,sha256=_b_jMTECWJGzncCiTg7FtTDSSeXGnbiAhaJhIsqbn9k
252
252
  ultralytics/utils/logger.py,sha256=hK1APBBHmlLAm0zbAFY7gf7Iaejy0PdwLWnnpboboGg,15129
253
253
  ultralytics/utils/loss.py,sha256=wJ0F2DpRTI9-e9adxIm2io0zcXRa0RTWFTOc7WmS1-A,39827
254
254
  ultralytics/utils/metrics.py,sha256=EWwkVWNmN_9rIsR1UOTLz3PiXOzflUE0iWFibydvXgM,68882
255
- ultralytics/utils/nms.py,sha256=AVOmPuUTEJqmq2J6rvjq-nHNxYIyabgzHdc41siyA0w,14161
255
+ ultralytics/utils/nms.py,sha256=SnZF0VRzY933YzI92NLzmLwuVzu56UNZ7sFT0FryCaw,14285
256
256
  ultralytics/utils/ops.py,sha256=yb0jlahjxqUT_xb3y9wz0kXn0rx2AryUgWdtLat3yWY,27010
257
257
  ultralytics/utils/patches.py,sha256=0-2G4jXCIPnMonlft-cPcjfFcOXQS6ODwUDNUwanfg4,6541
258
258
  ultralytics/utils/plotting.py,sha256=l5G4MT2pB_LLMFqSgFbKb7ip5VMrnpi3i5QmZWytRRU,48369
@@ -273,11 +273,13 @@ ultralytics/utils/callbacks/platform.py,sha256=a7T_8htoBB0uX1WIc392UJnhDjxkRyQMv
273
273
  ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
274
274
  ultralytics/utils/callbacks/tensorboard.py,sha256=_4nfGK1dDLn6ijpvphBDhc-AS8qhS3jjY2CAWB7SNF0,5283
275
275
  ultralytics/utils/callbacks/wb.py,sha256=ngQO8EJ1kxJDF1YajScVtzBbm26jGuejA0uWeOyvf5A,7685
276
- ultralytics/utils/export/__init__.py,sha256=uyRhb-0Z5FVf7vSz2Yba1m7g5m2U_ftAv4ThlmMsqZ8,10015
276
+ ultralytics/utils/export/__init__.py,sha256=Cfh-PwVfTF_lwPp-Ss4wiX4z8Sm1XRPklsqdFfmTZ30,333
277
+ ultralytics/utils/export/engine.py,sha256=V8ERERlpufTRm6k_7KOy9dUupAICC28W9TPO_7dkEJY,9979
277
278
  ultralytics/utils/export/imx.py,sha256=DH0rVe-gris7qA7bGT-WoOJHqWxLBAmei1JXmK-W7vM,11660
278
- ultralytics-8.3.222.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
279
- ultralytics-8.3.222.dist-info/METADATA,sha256=CAmjZTJwgxx_WwAGa142W7xIM8wjfyJoGx_kBMaItIw,37667
280
- ultralytics-8.3.222.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
281
- ultralytics-8.3.222.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
282
- ultralytics-8.3.222.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
283
- ultralytics-8.3.222.dist-info/RECORD,,
279
+ ultralytics/utils/export/tensorflow.py,sha256=CxraBn-5pIDSd_-0-0vQGMz8lv75vjSl6N7DYgVS3SU,9382
280
+ ultralytics-8.3.223.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
281
+ ultralytics-8.3.223.dist-info/METADATA,sha256=cmJXLXN705e0W2Hgph9hNPJoiMBaB-i2foZh3nUlttE,37668
282
+ ultralytics-8.3.223.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
283
+ ultralytics-8.3.223.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
284
+ ultralytics-8.3.223.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
285
+ ultralytics-8.3.223.dist-info/RECORD,,