dgenerate-ultralytics-headless 8.3.221__py3-none-any.whl → 8.3.223__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {dgenerate_ultralytics_headless-8.3.221.dist-info → dgenerate_ultralytics_headless-8.3.223.dist-info}/METADATA +2 -2
  2. {dgenerate_ultralytics_headless-8.3.221.dist-info → dgenerate_ultralytics_headless-8.3.223.dist-info}/RECORD +29 -27
  3. tests/test_python.py +5 -5
  4. ultralytics/__init__.py +1 -1
  5. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  6. ultralytics/cfg/datasets/lvis.yaml +5 -5
  7. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  8. ultralytics/data/base.py +1 -1
  9. ultralytics/data/utils.py +1 -1
  10. ultralytics/engine/exporter.py +46 -110
  11. ultralytics/engine/model.py +1 -1
  12. ultralytics/engine/trainer.py +1 -1
  13. ultralytics/models/rtdetr/val.py +1 -1
  14. ultralytics/models/yolo/classify/train.py +2 -2
  15. ultralytics/nn/autobackend.py +1 -1
  16. ultralytics/nn/modules/head.py +5 -30
  17. ultralytics/utils/__init__.py +4 -4
  18. ultralytics/utils/benchmarks.py +3 -1
  19. ultralytics/utils/export/__init__.py +4 -239
  20. ultralytics/utils/export/engine.py +240 -0
  21. ultralytics/utils/export/imx.py +39 -28
  22. ultralytics/utils/export/tensorflow.py +221 -0
  23. ultralytics/utils/metrics.py +2 -2
  24. ultralytics/utils/nms.py +4 -2
  25. ultralytics/utils/plotting.py +1 -1
  26. {dgenerate_ultralytics_headless-8.3.221.dist-info → dgenerate_ultralytics_headless-8.3.223.dist-info}/WHEEL +0 -0
  27. {dgenerate_ultralytics_headless-8.3.221.dist-info → dgenerate_ultralytics_headless-8.3.223.dist-info}/entry_points.txt +0 -0
  28. {dgenerate_ultralytics_headless-8.3.221.dist-info → dgenerate_ultralytics_headless-8.3.223.dist-info}/licenses/LICENSE +0 -0
  29. {dgenerate_ultralytics_headless-8.3.221.dist-info → dgenerate_ultralytics_headless-8.3.223.dist-info}/top_level.txt +0 -0
@@ -166,22 +166,8 @@ class Detect(nn.Module):
         self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
         self.shape = shape
 
-        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
-            box = x_cat[:, : self.reg_max * 4]
-            cls = x_cat[:, self.reg_max * 4 :]
-        else:
-            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-
-        if self.export and self.format in {"tflite", "edgetpu"}:
-            # Precompute normalization factor to increase numerical stability
-            # See https://github.com/ultralytics/ultralytics/issues/7371
-            grid_h = shape[2]
-            grid_w = shape[3]
-            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
-            norm = self.strides / (self.stride[0] * grid_size)
-            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
-        else:
-            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
+        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
+        dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
 
         return torch.cat((dbox, cls.sigmoid()), 1)
 
     def bias_init(self):
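The TF-specific slicing and TFLite normalization branches are removed; every export format now takes the same decode path. A minimal sketch of the surviving split, with illustrative YOLOv8-style sizes (reg_max=16, nc=80, 8400 anchors) and random stand-in tensors:

import torch

reg_max, nc, n = 16, 80, 8400                 # illustrative sizes, not from a real model
x_cat = torch.randn(1, 4 * reg_max + nc, n)   # stand-in for the concatenated head output
box, cls = x_cat.split((reg_max * 4, nc), 1)  # the single split that now serves every format
print(box.shape, cls.shape)                   # torch.Size([1, 64, 8400]) torch.Size([1, 80, 8400])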
@@ -391,20 +377,9 @@ class Pose(Detect):
         """Decode keypoints from predictions."""
         ndim = self.kpt_shape[1]
         if self.export:
-            if self.format in {
-                "tflite",
-                "edgetpu",
-            }:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
-                # Precompute normalization factor to increase numerical stability
-                y = kpts.view(bs, *self.kpt_shape, -1)
-                grid_h, grid_w = self.shape[2], self.shape[3]
-                grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
-                norm = self.strides / (self.stride[0] * grid_size)
-                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
-            else:
-                # NCNN fix
-                y = kpts.view(bs, *self.kpt_shape, -1)
-                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
+            # NCNN fix
+            y = kpts.view(bs, *self.kpt_shape, -1)
+            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
             if ndim == 3:
                 a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
             return a.view(bs, self.nk, -1)
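The keypoint decode likewise collapses to the single NCNN-friendly branch. A self-contained sketch with stand-in anchors and strides (shapes are illustrative, not taken from a real model):

import torch

bs, kpt_shape, n = 1, (17, 3), 8400                    # 17 keypoints of (x, y, visibility)
kpts = torch.randn(bs, kpt_shape[0] * kpt_shape[1], n)
anchors = torch.rand(2, n)                             # stand-in grid-cell centers
strides = torch.full((1, n), 8.0)                      # stand-in per-anchor strides

y = kpts.view(bs, *kpt_shape, -1)                      # (1, 17, 3, 8400)
a = (y[:, :, :2] * 2.0 + (anchors - 0.5)) * strides    # single decode path, no TFLite branch
print(a.shape)                                         # torch.Size([1, 17, 2, 8400])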
@@ -795,13 +795,13 @@ def is_pip_package(filepath: str = __name__) -> bool:
 
 def is_dir_writeable(dir_path: str | Path) -> bool:
     """
-    Check if a directory is writeable.
+    Check if a directory is writable.
 
     Args:
         dir_path (str | Path): The path to the directory.
 
     Returns:
-        (bool): True if the directory is writeable, False otherwise.
+        (bool): True if the directory is writable, False otherwise.
     """
     return os.access(str(dir_path), os.W_OK)
 
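Only the docstring spelling changes here; the os.access check is untouched. A quick usage sketch of the function as defined above:

import os
from pathlib import Path

def is_dir_writeable(dir_path) -> bool:
    """Return True if the directory is writable (the same os.access check as above)."""
    return os.access(str(dir_path), os.W_OK)

print(is_dir_writeable(Path.home()))  # typically True
print(is_dir_writeable("/proc"))      # typically False for unprivileged users on Linux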
@@ -882,14 +882,14 @@ def get_user_config_dir(sub_dir="Ultralytics"):
         p.mkdir(parents=True, exist_ok=True)
         return p
 
-    # Fallbacks for Docker, GCP/AWS functions where only /tmp is writeable
+    # Fallbacks for Docker, GCP/AWS functions where only /tmp is writable
     for alt in [Path("/tmp") / sub_dir, Path.cwd() / sub_dir]:
         if alt.exists():
             return alt
         if is_dir_writeable(alt.parent):
             alt.mkdir(parents=True, exist_ok=True)
             LOGGER.warning(
-                f"user config directory '{p}' is not writeable, using '{alt}'. Set YOLO_CONFIG_DIR to override."
+                f"user config directory '{p}' is not writable, using '{alt}'. Set YOLO_CONFIG_DIR to override."
             )
             return alt
 
@@ -144,7 +144,9 @@ def benchmark(
         if format == "imx":
             assert not is_end2end
             assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
-            assert model.task == "detect", "IMX only supported for detection task"
+            assert model.task in {"detect", "classify", "pose"}, (
+                "IMX export is only supported for detection, classification and pose estimation tasks"
+            )
             assert "C2f" in model.__str__(), "IMX only supported for YOLOv8n and YOLO11n"
         if format == "rknn":
             assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
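The task gate widens from detection-only to a set-membership check. Note that wrapping only the message in parentheses keeps this a plain assert; parenthesizing the whole condition-message pair would instead create an always-truthy tuple. A minimal sketch with a stand-in task value:

task = "pose"  # stand-in for model.task
assert task in {"detect", "classify", "pose"}, (
    "IMX export is only supported for detection, classification and pose estimation tasks"
)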
@@ -1,242 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-from __future__ import annotations
+from .engine import onnx2engine, torch2onnx
+from .imx import torch2imx
+from .tensorflow import keras2pb, onnx2saved_model, pb2tfjs, tflite2edgetpu
 
-import json
-from pathlib import Path
-
-import torch
-
-from ultralytics.utils import IS_JETSON, LOGGER
-from ultralytics.utils.torch_utils import TORCH_2_4
-
-from .imx import torch2imx  # noqa
-
-
-def torch2onnx(
-    torch_model: torch.nn.Module,
-    im: torch.Tensor,
-    onnx_file: str,
-    opset: int = 14,
-    input_names: list[str] = ["images"],
-    output_names: list[str] = ["output0"],
-    dynamic: bool | dict = False,
-) -> None:
-    """
-    Export a PyTorch model to ONNX format.
-
-    Args:
-        torch_model (torch.nn.Module): The PyTorch model to export.
-        im (torch.Tensor): Example input tensor for the model.
-        onnx_file (str): Path to save the exported ONNX file.
-        opset (int): ONNX opset version to use for export.
-        input_names (list[str]): List of input tensor names.
-        output_names (list[str]): List of output tensor names.
-        dynamic (bool | dict, optional): Whether to enable dynamic axes.
-
-    Notes:
-        Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
-    """
-    kwargs = {"dynamo": False} if TORCH_2_4 else {}
-    torch.onnx.export(
-        torch_model,
-        im,
-        onnx_file,
-        verbose=False,
-        opset_version=opset,
-        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
-        input_names=input_names,
-        output_names=output_names,
-        dynamic_axes=dynamic or None,
-        **kwargs,
-    )
-
-
-def onnx2engine(
-    onnx_file: str,
-    engine_file: str | None = None,
-    workspace: int | None = None,
-    half: bool = False,
-    int8: bool = False,
-    dynamic: bool = False,
-    shape: tuple[int, int, int, int] = (1, 3, 640, 640),
-    dla: int | None = None,
-    dataset=None,
-    metadata: dict | None = None,
-    verbose: bool = False,
-    prefix: str = "",
-) -> None:
-    """
-    Export a YOLO model to TensorRT engine format.
-
-    Args:
-        onnx_file (str): Path to the ONNX file to be converted.
-        engine_file (str, optional): Path to save the generated TensorRT engine file.
-        workspace (int, optional): Workspace size in GB for TensorRT.
-        half (bool, optional): Enable FP16 precision.
-        int8 (bool, optional): Enable INT8 precision.
-        dynamic (bool, optional): Enable dynamic input shapes.
-        shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
-        dla (int, optional): DLA core to use (Jetson devices only).
-        dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
-        metadata (dict, optional): Metadata to include in the engine file.
-        verbose (bool, optional): Enable verbose logging.
-        prefix (str, optional): Prefix for log messages.
-
-    Raises:
-        ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
-        RuntimeError: If the ONNX file cannot be parsed.
-
-    Notes:
-        TensorRT version compatibility is handled for workspace size and engine building.
-        INT8 calibration requires a dataset and generates a calibration cache.
-        Metadata is serialized and written to the engine file if provided.
-    """
-    import tensorrt as trt
-
-    engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
-
-    logger = trt.Logger(trt.Logger.INFO)
-    if verbose:
-        logger.min_severity = trt.Logger.Severity.VERBOSE
-
-    # Engine builder
-    builder = trt.Builder(logger)
-    config = builder.create_builder_config()
-    workspace_bytes = int((workspace or 0) * (1 << 30))
-    is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10  # is TensorRT >= 10
-    if is_trt10 and workspace_bytes > 0:
-        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
-    elif workspace_bytes > 0:  # TensorRT versions 7, 8
-        config.max_workspace_size = workspace_bytes
-    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-    network = builder.create_network(flag)
-    half = builder.platform_has_fast_fp16 and half
-    int8 = builder.platform_has_fast_int8 and int8
-
-    # Optionally switch to DLA if enabled
-    if dla is not None:
-        if not IS_JETSON:
-            raise ValueError("DLA is only available on NVIDIA Jetson devices")
-        LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
-        if not half and not int8:
-            raise ValueError(
-                "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
-            )
-        config.default_device_type = trt.DeviceType.DLA
-        config.DLA_core = int(dla)
-        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
-
-    # Read ONNX file
-    parser = trt.OnnxParser(network, logger)
-    if not parser.parse_from_file(onnx_file):
-        raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
-
-    # Network inputs
-    inputs = [network.get_input(i) for i in range(network.num_inputs)]
-    outputs = [network.get_output(i) for i in range(network.num_outputs)]
-    for inp in inputs:
-        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
-    for out in outputs:
-        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
-
-    if dynamic:
-        profile = builder.create_optimization_profile()
-        min_shape = (1, shape[1], 32, 32)  # minimum input shape
-        max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:]))  # max input shape
-        for inp in inputs:
-            profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
-        config.add_optimization_profile(profile)
-        if int8:
-            config.set_calibration_profile(profile)
-
-    LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
-    if int8:
-        config.set_flag(trt.BuilderFlag.INT8)
-        config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
-
-        class EngineCalibrator(trt.IInt8Calibrator):
-            """
-            Custom INT8 calibrator for TensorRT engine optimization.
-
-            This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
-            using a dataset. It handles batch generation, caching, and calibration algorithm selection.
-
-            Attributes:
-                dataset: Dataset for calibration.
-                data_iter: Iterator over the calibration dataset.
-                algo (trt.CalibrationAlgoType): Calibration algorithm type.
-                batch (int): Batch size for calibration.
-                cache (Path): Path to save the calibration cache.
-
-            Methods:
-                get_algorithm: Get the calibration algorithm to use.
-                get_batch_size: Get the batch size to use for calibration.
-                get_batch: Get the next batch to use for calibration.
-                read_calibration_cache: Use existing cache instead of calibrating again.
-                write_calibration_cache: Write calibration cache to disk.
-            """
-
-            def __init__(
-                self,
-                dataset,  # ultralytics.data.build.InfiniteDataLoader
-                cache: str = "",
-            ) -> None:
-                """Initialize the INT8 calibrator with dataset and cache path."""
-                trt.IInt8Calibrator.__init__(self)
-                self.dataset = dataset
-                self.data_iter = iter(dataset)
-                self.algo = (
-                    trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2  # DLA quantization needs ENTROPY_CALIBRATION_2
-                    if dla is not None
-                    else trt.CalibrationAlgoType.MINMAX_CALIBRATION
-                )
-                self.batch = dataset.batch_size
-                self.cache = Path(cache)
-
-            def get_algorithm(self) -> trt.CalibrationAlgoType:
-                """Get the calibration algorithm to use."""
-                return self.algo
-
-            def get_batch_size(self) -> int:
-                """Get the batch size to use for calibration."""
-                return self.batch or 1
-
-            def get_batch(self, names) -> list[int] | None:
-                """Get the next batch to use for calibration, as a list of device memory pointers."""
-                try:
-                    im0s = next(self.data_iter)["img"] / 255.0
-                    im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
-                    return [int(im0s.data_ptr())]
-                except StopIteration:
-                    # Return None to signal to TensorRT there is no calibration data remaining
-                    return None
-
-            def read_calibration_cache(self) -> bytes | None:
-                """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
-                if self.cache.exists() and self.cache.suffix == ".cache":
-                    return self.cache.read_bytes()
-
-            def write_calibration_cache(self, cache: bytes) -> None:
-                """Write calibration cache to disk."""
-                _ = self.cache.write_bytes(cache)
-
-        # Load dataset w/ builder (for batching) and calibrate
-        config.int8_calibrator = EngineCalibrator(
-            dataset=dataset,
-            cache=str(Path(onnx_file).with_suffix(".cache")),
-        )
-
-    elif half:
-        config.set_flag(trt.BuilderFlag.FP16)
-
-    # Write file
-    build = builder.build_serialized_network if is_trt10 else builder.build_engine
-    with build(network, config) as engine, open(engine_file, "wb") as t:
-        # Metadata
-        if metadata is not None:
-            meta = json.dumps(metadata)
-            t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
-            t.write(meta.encode())
-        # Model
-        t.write(engine if is_trt10 else engine.serialize())
+__all__ = ["keras2pb", "onnx2engine", "onnx2saved_model", "pb2tfjs", "tflite2edgetpu", "torch2imx", "torch2onnx"]
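Because __init__.py re-exports the moved functions, the split into engine.py, imx.py, and tensorflow.py should be transparent to callers. A sketch of the two equivalent import styles (assuming the package is installed):

from ultralytics.utils.export import torch2onnx                               # via the re-exports above
from ultralytics.utils.export.engine import torch2onnx as torch2onnx_direct  # straight from the submodule

assert torch2onnx is torch2onnx_direct  # the same function object either way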
@@ -0,0 +1,240 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import torch
+
+from ultralytics.utils import IS_JETSON, LOGGER
+from ultralytics.utils.torch_utils import TORCH_2_4
+
+
+def torch2onnx(
+    torch_model: torch.nn.Module,
+    im: torch.Tensor,
+    onnx_file: str,
+    opset: int = 14,
+    input_names: list[str] = ["images"],
+    output_names: list[str] = ["output0"],
+    dynamic: bool | dict = False,
+) -> None:
+    """
+    Export a PyTorch model to ONNX format.
+
+    Args:
+        torch_model (torch.nn.Module): The PyTorch model to export.
+        im (torch.Tensor): Example input tensor for the model.
+        onnx_file (str): Path to save the exported ONNX file.
+        opset (int): ONNX opset version to use for export.
+        input_names (list[str]): List of input tensor names.
+        output_names (list[str]): List of output tensor names.
+        dynamic (bool | dict, optional): Whether to enable dynamic axes.
+
+    Notes:
+        Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
+    """
+    kwargs = {"dynamo": False} if TORCH_2_4 else {}
+    torch.onnx.export(
+        torch_model,
+        im,
+        onnx_file,
+        verbose=False,
+        opset_version=opset,
+        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes=dynamic or None,
+        **kwargs,
+    )
+
+
+def onnx2engine(
+    onnx_file: str,
+    engine_file: str | None = None,
+    workspace: int | None = None,
+    half: bool = False,
+    int8: bool = False,
+    dynamic: bool = False,
+    shape: tuple[int, int, int, int] = (1, 3, 640, 640),
+    dla: int | None = None,
+    dataset=None,
+    metadata: dict | None = None,
+    verbose: bool = False,
+    prefix: str = "",
+) -> None:
+    """
+    Export a YOLO model to TensorRT engine format.
+
+    Args:
+        onnx_file (str): Path to the ONNX file to be converted.
+        engine_file (str, optional): Path to save the generated TensorRT engine file.
+        workspace (int, optional): Workspace size in GB for TensorRT.
+        half (bool, optional): Enable FP16 precision.
+        int8 (bool, optional): Enable INT8 precision.
+        dynamic (bool, optional): Enable dynamic input shapes.
+        shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
+        dla (int, optional): DLA core to use (Jetson devices only).
+        dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
+        metadata (dict, optional): Metadata to include in the engine file.
+        verbose (bool, optional): Enable verbose logging.
+        prefix (str, optional): Prefix for log messages.
+
+    Raises:
+        ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
+        RuntimeError: If the ONNX file cannot be parsed.
+
+    Notes:
+        TensorRT version compatibility is handled for workspace size and engine building.
+        INT8 calibration requires a dataset and generates a calibration cache.
+        Metadata is serialized and written to the engine file if provided.
+    """
+    import tensorrt as trt
+
+    engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
+
+    logger = trt.Logger(trt.Logger.INFO)
+    if verbose:
+        logger.min_severity = trt.Logger.Severity.VERBOSE
+
+    # Engine builder
+    builder = trt.Builder(logger)
+    config = builder.create_builder_config()
+    workspace_bytes = int((workspace or 0) * (1 << 30))
+    is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10  # is TensorRT >= 10
+    if is_trt10 and workspace_bytes > 0:
+        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
+    elif workspace_bytes > 0:  # TensorRT versions 7, 8
+        config.max_workspace_size = workspace_bytes
+    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network = builder.create_network(flag)
+    half = builder.platform_has_fast_fp16 and half
+    int8 = builder.platform_has_fast_int8 and int8
+
+    # Optionally switch to DLA if enabled
+    if dla is not None:
+        if not IS_JETSON:
+            raise ValueError("DLA is only available on NVIDIA Jetson devices")
+        LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
+        if not half and not int8:
+            raise ValueError(
+                "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
+            )
+        config.default_device_type = trt.DeviceType.DLA
+        config.DLA_core = int(dla)
+        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
+
+    # Read ONNX file
+    parser = trt.OnnxParser(network, logger)
+    if not parser.parse_from_file(onnx_file):
+        raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
+
+    # Network inputs
+    inputs = [network.get_input(i) for i in range(network.num_inputs)]
+    outputs = [network.get_output(i) for i in range(network.num_outputs)]
+    for inp in inputs:
+        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
+    for out in outputs:
+        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
+
+    if dynamic:
+        profile = builder.create_optimization_profile()
+        min_shape = (1, shape[1], 32, 32)  # minimum input shape
+        max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:]))  # max input shape
+        for inp in inputs:
+            profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
+        config.add_optimization_profile(profile)
+        if int8:
+            config.set_calibration_profile(profile)
+
+    LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
+    if int8:
+        config.set_flag(trt.BuilderFlag.INT8)
+        config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
+
+        class EngineCalibrator(trt.IInt8Calibrator):
+            """
+            Custom INT8 calibrator for TensorRT engine optimization.
+
+            This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
+            using a dataset. It handles batch generation, caching, and calibration algorithm selection.
+
+            Attributes:
+                dataset: Dataset for calibration.
+                data_iter: Iterator over the calibration dataset.
+                algo (trt.CalibrationAlgoType): Calibration algorithm type.
+                batch (int): Batch size for calibration.
+                cache (Path): Path to save the calibration cache.
+
+            Methods:
+                get_algorithm: Get the calibration algorithm to use.
+                get_batch_size: Get the batch size to use for calibration.
+                get_batch: Get the next batch to use for calibration.
+                read_calibration_cache: Use existing cache instead of calibrating again.
+                write_calibration_cache: Write calibration cache to disk.
+            """
+
+            def __init__(
+                self,
+                dataset,  # ultralytics.data.build.InfiniteDataLoader
+                cache: str = "",
+            ) -> None:
+                """Initialize the INT8 calibrator with dataset and cache path."""
+                trt.IInt8Calibrator.__init__(self)
+                self.dataset = dataset
+                self.data_iter = iter(dataset)
+                self.algo = (
+                    trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2  # DLA quantization needs ENTROPY_CALIBRATION_2
+                    if dla is not None
+                    else trt.CalibrationAlgoType.MINMAX_CALIBRATION
+                )
+                self.batch = dataset.batch_size
+                self.cache = Path(cache)
+
+            def get_algorithm(self) -> trt.CalibrationAlgoType:
+                """Get the calibration algorithm to use."""
+                return self.algo
+
+            def get_batch_size(self) -> int:
+                """Get the batch size to use for calibration."""
+                return self.batch or 1
+
+            def get_batch(self, names) -> list[int] | None:
+                """Get the next batch to use for calibration, as a list of device memory pointers."""
+                try:
+                    im0s = next(self.data_iter)["img"] / 255.0
+                    im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
+                    return [int(im0s.data_ptr())]
+                except StopIteration:
+                    # Return None to signal to TensorRT there is no calibration data remaining
+                    return None
+
+            def read_calibration_cache(self) -> bytes | None:
+                """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
+                if self.cache.exists() and self.cache.suffix == ".cache":
+                    return self.cache.read_bytes()
+
+            def write_calibration_cache(self, cache: bytes) -> None:
+                """Write calibration cache to disk."""
+                _ = self.cache.write_bytes(cache)
+
+        # Load dataset w/ builder (for batching) and calibrate
+        config.int8_calibrator = EngineCalibrator(
+            dataset=dataset,
+            cache=str(Path(onnx_file).with_suffix(".cache")),
+        )
+
+    elif half:
+        config.set_flag(trt.BuilderFlag.FP16)
+
+    # Write file
+    build = builder.build_serialized_network if is_trt10 else builder.build_engine
+    with build(network, config) as engine, open(engine_file, "wb") as t:
+        # Metadata
+        if metadata is not None:
+            meta = json.dumps(metadata)
+            t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
+            t.write(meta.encode())
+        # Model
+        t.write(engine if is_trt10 else engine.serialize())
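A hypothetical usage sketch of the relocated converters (assumes the package is installed; onnx2engine additionally requires the tensorrt wheel and an NVIDIA GPU, and "model.onnx" is an illustrative path):

import torch

from ultralytics.utils.export.engine import torch2onnx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.SiLU()).eval()  # stand-in, not a YOLO network
im = torch.zeros(1, 3, 640, 640)  # example input matching the default (1, 3, 640, 640) shape
torch2onnx(model, im, "model.onnx", opset=14)

# The engine step needs the tensorrt wheel and an NVIDIA GPU:
# onnx2engine("model.onnx", engine_file="model.engine", half=True)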