dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/export/engine.py (new file)
@@ -0,0 +1,246 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import torch
+
+from ultralytics.utils import IS_JETSON, LOGGER
+from ultralytics.utils.torch_utils import TORCH_2_4
+
+
+def torch2onnx(
+    torch_model: torch.nn.Module,
+    im: torch.Tensor,
+    onnx_file: str,
+    opset: int = 14,
+    input_names: list[str] = ["images"],
+    output_names: list[str] = ["output0"],
+    dynamic: bool | dict = False,
+) -> None:
+    """Export a PyTorch model to ONNX format.
+
+    Args:
+        torch_model (torch.nn.Module): The PyTorch model to export.
+        im (torch.Tensor): Example input tensor for the model.
+        onnx_file (str): Path to save the exported ONNX file.
+        opset (int): ONNX opset version to use for export.
+        input_names (list[str]): List of input tensor names.
+        output_names (list[str]): List of output tensor names.
+        dynamic (bool | dict, optional): Whether to enable dynamic axes.
+
+    Notes:
+        Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
+    """
+    kwargs = {"dynamo": False} if TORCH_2_4 else {}
+    torch.onnx.export(
+        torch_model,
+        im,
+        onnx_file,
+        verbose=False,
+        opset_version=opset,
+        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes=dynamic or None,
+        **kwargs,
+    )
+
+
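As a quick orientation, here is a minimal sketch of calling the new `torch2onnx` helper on a toy module; the module, input shape, and file name are illustrative assumptions, and the import path follows the new `ultralytics/utils/export/engine.py` module added in this release:

```python
import torch

from ultralytics.utils.export.engine import torch2onnx

# Toy stand-in model; any torch.nn.Module with a matching example input works
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.SiLU()).eval()
im = torch.zeros(1, 3, 640, 640)  # example input traced through the graph
torch2onnx(model, im, "toy.onnx", opset=14)  # writes toy.onnx with input "images", output "output0"
```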
+def onnx2engine(
+    onnx_file: str,
+    engine_file: str | None = None,
+    workspace: int | None = None,
+    half: bool = False,
+    int8: bool = False,
+    dynamic: bool = False,
+    shape: tuple[int, int, int, int] = (1, 3, 640, 640),
+    dla: int | None = None,
+    dataset=None,
+    metadata: dict | None = None,
+    verbose: bool = False,
+    prefix: str = "",
+) -> None:
+    """Export a YOLO model to TensorRT engine format.
+
+    Args:
+        onnx_file (str): Path to the ONNX file to be converted.
+        engine_file (str, optional): Path to save the generated TensorRT engine file.
+        workspace (int, optional): Workspace size in GB for TensorRT.
+        half (bool, optional): Enable FP16 precision.
+        int8 (bool, optional): Enable INT8 precision.
+        dynamic (bool, optional): Enable dynamic input shapes.
+        shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
+        dla (int, optional): DLA core to use (Jetson devices only).
+        dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
+        metadata (dict, optional): Metadata to include in the engine file.
+        verbose (bool, optional): Enable verbose logging.
+        prefix (str, optional): Prefix for log messages.
+
+    Raises:
+        ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
+        RuntimeError: If the ONNX file cannot be parsed.
+
+    Notes:
+        TensorRT version compatibility is handled for workspace size and engine building.
+        INT8 calibration requires a dataset and generates a calibration cache.
+        Metadata is serialized and written to the engine file if provided.
+    """
+    import tensorrt as trt
+
+    engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
+
+    logger = trt.Logger(trt.Logger.INFO)
+    if verbose:
+        logger.min_severity = trt.Logger.Severity.VERBOSE
+
+    # Engine builder
+    builder = trt.Builder(logger)
+    config = builder.create_builder_config()
+    workspace_bytes = int((workspace or 0) * (1 << 30))
+    is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10  # is TensorRT >= 10
+    if is_trt10 and workspace_bytes > 0:
+        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
+    elif workspace_bytes > 0:  # TensorRT versions 7, 8
+        config.max_workspace_size = workspace_bytes
+    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network = builder.create_network(flag)
+    half = builder.platform_has_fast_fp16 and half
+    int8 = builder.platform_has_fast_int8 and int8
+
+    # Optionally switch to DLA if enabled
+    if dla is not None:
+        if not IS_JETSON:
+            raise ValueError("DLA is only available on NVIDIA Jetson devices")
+        LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
+        if not half and not int8:
+            raise ValueError(
+                "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
+            )
+        config.default_device_type = trt.DeviceType.DLA
+        config.DLA_core = int(dla)
+        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
+
+    # Read ONNX file
+    parser = trt.OnnxParser(network, logger)
+    if not parser.parse_from_file(onnx_file):
+        raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
+
+    # Network inputs
+    inputs = [network.get_input(i) for i in range(network.num_inputs)]
+    outputs = [network.get_output(i) for i in range(network.num_outputs)]
+    for inp in inputs:
+        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
+    for out in outputs:
+        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
+
+    if dynamic:
+        profile = builder.create_optimization_profile()
+        min_shape = (1, shape[1], 32, 32)  # minimum input shape
+        max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:]))  # max input shape
+        for inp in inputs:
+            profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
+        config.add_optimization_profile(profile)
+        if int8 and not is_trt10:  # deprecated in TensorRT 10, causes internal errors
+            config.set_calibration_profile(profile)
+
+    LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
+    if int8:
+        config.set_flag(trt.BuilderFlag.INT8)
+        config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
+
+        class EngineCalibrator(trt.IInt8Calibrator):
+            """Custom INT8 calibrator for TensorRT engine optimization.
+
+            This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration using
+            a dataset. It handles batch generation, caching, and calibration algorithm selection.
+
+            Attributes:
+                dataset: Dataset for calibration.
+                data_iter: Iterator over the calibration dataset.
+                algo (trt.CalibrationAlgoType): Calibration algorithm type.
+                batch (int): Batch size for calibration.
+                cache (Path): Path to save the calibration cache.
+
+            Methods:
+                get_algorithm: Get the calibration algorithm to use.
+                get_batch_size: Get the batch size to use for calibration.
+                get_batch: Get the next batch to use for calibration.
+                read_calibration_cache: Use existing cache instead of calibrating again.
+                write_calibration_cache: Write calibration cache to disk.
+            """
+
+            def __init__(
+                self,
+                dataset,  # ultralytics.data.build.InfiniteDataLoader
+                cache: str = "",
+            ) -> None:
+                """Initialize the INT8 calibrator with dataset and cache path."""
+                trt.IInt8Calibrator.__init__(self)
+                self.dataset = dataset
+                self.data_iter = iter(dataset)
+                self.algo = (
+                    trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2  # DLA quantization needs ENTROPY_CALIBRATION_2
+                    if dla is not None
+                    else trt.CalibrationAlgoType.MINMAX_CALIBRATION
+                )
+                self.batch = dataset.batch_size
+                self.cache = Path(cache)
+
+            def get_algorithm(self) -> trt.CalibrationAlgoType:
+                """Get the calibration algorithm to use."""
+                return self.algo
+
+            def get_batch_size(self) -> int:
+                """Get the batch size to use for calibration."""
+                return self.batch or 1
+
+            def get_batch(self, names) -> list[int] | None:
+                """Get the next batch to use for calibration, as a list of device memory pointers."""
+                try:
+                    im0s = next(self.data_iter)["img"] / 255.0
+                    im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
+                    return [int(im0s.data_ptr())]
+                except StopIteration:
+                    # Return None to signal to TensorRT there is no calibration data remaining
+                    return None
+
+            def read_calibration_cache(self) -> bytes | None:
+                """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
+                if self.cache.exists() and self.cache.suffix == ".cache":
+                    return self.cache.read_bytes()
+
+            def write_calibration_cache(self, cache: bytes) -> None:
+                """Write calibration cache to disk."""
+                _ = self.cache.write_bytes(cache)
+
+        # Load dataset w/ builder (for batching) and calibrate
+        config.int8_calibrator = EngineCalibrator(
+            dataset=dataset,
+            cache=str(Path(onnx_file).with_suffix(".cache")),
+        )
+
+    elif half:
+        config.set_flag(trt.BuilderFlag.FP16)
+
+    # Write file
+    if is_trt10:
+        # TensorRT 10+ returns bytes directly, not a context manager
+        engine = builder.build_serialized_network(network, config)
+        if engine is None:
+            raise RuntimeError("TensorRT engine build failed, check logs for errors")
+        with open(engine_file, "wb") as t:
+            if metadata is not None:
+                meta = json.dumps(metadata)
+                t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
+                t.write(meta.encode())
+            t.write(engine)
+    else:
+        with builder.build_engine(network, config) as engine, open(engine_file, "wb") as t:
+            if metadata is not None:
+                meta = json.dumps(metadata)
+                t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
+                t.write(meta.encode())
+            t.write(engine.serialize())
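A note on the file layout `onnx2engine` writes: when `metadata` is provided, the engine file begins with a 4-byte little-endian signed length, then the UTF-8 JSON metadata, then the serialized engine bytes. A minimal sketch of a reader for that header (the helper name is ours, not part of the package):

```python
import json
from pathlib import Path


def read_engine_metadata(engine_file: str) -> dict:
    """Parse the length-prefixed JSON header that onnx2engine prepends when metadata is given."""
    data = Path(engine_file).read_bytes()
    meta_len = int.from_bytes(data[:4], byteorder="little", signed=True)  # 4-byte LE length prefix
    return json.loads(data[4 : 4 + meta_len].decode())
```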
ultralytics/utils/export/imx.py
@@ -3,20 +3,63 @@
 from __future__ import annotations
 
 import subprocess
+import sys
 import types
 from pathlib import Path
+from shutil import which
 
+import numpy as np
 import torch
 
-from ultralytics.nn.modules import Detect, Pose
-from ultralytics.utils import LOGGER
+from ultralytics.nn.modules import Detect, Pose, Segment
+from ultralytics.utils import LOGGER, WINDOWS
+from ultralytics.utils.patches import onnx_export_patch
 from ultralytics.utils.tal import make_anchors
 from ultralytics.utils.torch_utils import copy_attr
 
+# Configuration for Model Compression Toolkit (MCT) quantization
+MCT_CONFIG = {
+    "YOLO11": {
+        "detect": {
+            "layer_names": ["sub", "mul_2", "add_14", "cat_19"],
+            "weights_memory": 2585350.2439,
+            "n_layers": {238, 239},
+        },
+        "pose": {
+            "layer_names": ["sub", "mul_2", "add_14", "cat_21", "cat_22", "mul_4", "add_15"],
+            "weights_memory": 2437771.67,
+            "n_layers": {257, 258},
+        },
+        "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": {112}},
+        "segment": {
+            "layer_names": ["sub", "mul_2", "add_14", "cat_21"],
+            "weights_memory": 2466604.8,
+            "n_layers": {265, 266},
+        },
+    },
+    "YOLOv8": {
+        "detect": {
+            "layer_names": ["sub", "mul", "add_6", "cat_15"],
+            "weights_memory": 2550540.8,
+            "n_layers": {168, 169},
+        },
+        "pose": {
+            "layer_names": ["add_7", "mul_2", "cat_17", "mul", "sub", "add_6", "cat_18"],
+            "weights_memory": 2482451.85,
+            "n_layers": {187, 188},
+        },
+        "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": {73}},
+        "segment": {
+            "layer_names": ["sub", "mul", "add_6", "cat_17"],
+            "weights_memory": 2580060.0,
+            "n_layers": {195, 196},
+        },
+    },
+}
+
 
 class FXModel(torch.nn.Module):
-    """
-    A custom model class for torch.fx compatibility.
+    """A custom model class for torch.fx compatibility.
 
     This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
     manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
@@ -27,8 +70,7 @@ class FXModel(torch.nn.Module):
     """
 
     def __init__(self, model, imgsz=(640, 640)):
-        """
-        Initialize the FXModel.
+        """Initialize the FXModel.
 
         Args:
             model (nn.Module): The original model to wrap for torch.fx compatibility.
@@ -41,8 +83,7 @@ class FXModel(torch.nn.Module):
         self.imgsz = imgsz
 
     def forward(self, x):
-        """
-        Forward pass through the model.
+        """Forward pass through the model.
 
         This method performs the forward pass through the model, handling the dependencies between layers and saving
         intermediate outputs.
@@ -68,30 +109,47 @@ class FXModel(torch.nn.Module):
                 )
             if type(m) is Pose:
                 m.forward = types.MethodType(pose_forward, m)  # bind method to Detect
+            if type(m) is Segment:
+                m.forward = types.MethodType(segment_forward, m)  # bind method to Detect
             x = m(x)  # run
             y.append(x)  # save output
         return x
 
 
-def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
+def _inference(self, x: dict[str, torch.Tensor]) -> tuple[torch.Tensor]:
     """Decode boxes and cls scores for imx object detection."""
-    x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
-    box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-    dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
-    return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
+    dbox = self.decode_bboxes(self.dfl(x["boxes"]), self.anchors.unsqueeze(0)) * self.strides
+    return dbox.transpose(1, 2), x["scores"].sigmoid().permute(0, 2, 1)
 
 
 def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Forward pass for imx pose estimation, including keypoint decoding."""
     bs = x[0].shape[0]  # batch size
-    kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
+    nk_out = getattr(self, "nk_output", self.nk)
+    kpt = torch.cat([self.cv4[i](x[i]).view(bs, nk_out, -1) for i in range(self.nl)], -1)
+
+    # If using Pose26 with 5 dims, convert to 3 dims for export
+    if hasattr(self, "nk_output") and self.nk_output != self.nk:
+        spatial = kpt.shape[-1]
+        kpt = kpt.view(bs, self.kpt_shape[0], self.kpt_shape[1] + 2, spatial)
+        kpt = kpt[:, :, :-2, :]  # Remove sigma_x, sigma_y
+        kpt = kpt.view(bs, self.nk, spatial)
     x = Detect.forward(self, x)
-    pred_kpt = self.kpts_decode(bs, kpt)
-    return (*x, pred_kpt.permute(0, 2, 1))
+    pred_kpt = self.kpts_decode(kpt)
+    return *x, pred_kpt.permute(0, 2, 1)
+
+
+def segment_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Forward pass for imx segmentation."""
+    p = self.proto(x[0])  # mask protos
+    bs = p.shape[0]  # batch size
+    mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
+    x = Detect.forward(self, x)
+    return *x, mc.transpose(1, 2), p
 
 
 class NMSWrapper(torch.nn.Module):
-    """Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers."""
+    """Wrap PyTorch Module with multiclass_nms layer from edge-mdt-cl."""
 
     def __init__(
         self,
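The new Pose26 branch in `pose_forward` above trims two extra per-keypoint sigma channels before decoding. A self-contained shape walkthrough of that reshape/slice, with illustrative sizes (17 COCO keypoints, 8400 anchors):

```python
import torch

bs, spatial = 1, 8400
kpt_shape = (17, 3)                            # (keypoints, dims) expected downstream
nk = kpt_shape[0] * kpt_shape[1]               # 51 channels after trimming
nk_output = kpt_shape[0] * (kpt_shape[1] + 2)  # 85 channels when the head also predicts sigma_x, sigma_y

kpt = torch.randn(bs, nk_output, spatial)                    # raw head output, (1, 85, 8400)
kpt = kpt.view(bs, kpt_shape[0], kpt_shape[1] + 2, spatial)  # (1, 17, 5, 8400)
kpt = kpt[:, :, :-2, :]                                      # drop sigma_x, sigma_y -> (1, 17, 3, 8400)
kpt = kpt.view(bs, nk, spatial)                              # (1, 51, 8400)
```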
@@ -101,8 +159,7 @@ class NMSWrapper(torch.nn.Module):
         max_detections: int = 300,
         task: str = "detect",
     ):
-        """
-        Initialize NMSWrapper with PyTorch Module and NMS parameters.
+        """Initialize NMSWrapper with PyTorch Module and NMS parameters.
 
         Args:
             model (torch.nn.Module): Model instance.
@@ -120,7 +177,7 @@ class NMSWrapper(torch.nn.Module):
 
     def forward(self, images):
         """Forward pass with model inference and NMS post-processing."""
-        from sony_custom_layers.pytorch import multiclass_nms_with_indices
+        from edgemdt_cl.pytorch.nms.nms_with_indices import multiclass_nms_with_indices
 
         # model inference
         outputs = self.model(images)
@@ -136,6 +193,10 @@ class NMSWrapper(torch.nn.Module):
             kpts = outputs[2]  # (bs, max_detections, kpts 17*3)
             out_kpts = torch.gather(kpts, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, kpts.size(-1)))
             return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_kpts
+        if self.task == "segment":
+            mc, proto = outputs[2], outputs[3]
+            out_mc = torch.gather(mc, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, mc.size(-1)))
+            return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_mc, proto
         return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, nms_outputs.n_valid
 
 
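The new segment branch reuses the same gather-by-NMS-indices trick as the pose branch: the (bs, max_det) index tensor is expanded across the channel dimension so `torch.gather` picks out whole coefficient rows. A standalone sketch with made-up sizes:

```python
import torch

bs, n_anchors, n_coeffs, max_det = 1, 8400, 32, 300
mc = torch.randn(bs, n_anchors, n_coeffs)             # mask coefficients per anchor
indices = torch.randint(0, n_anchors, (bs, max_det))  # anchor indices kept by NMS

# Expand indices to (bs, max_det, n_coeffs) so gather selects full coefficient rows along dim 1
out_mc = torch.gather(mc, 1, indices.unsqueeze(-1).expand(-1, -1, n_coeffs))
assert out_mc.shape == (bs, max_det, n_coeffs)
```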
@@ -150,12 +211,10 @@ def torch2imx(
     dataset=None,
     prefix: str = "",
 ):
-    """
-    Export YOLO model to IMX format for deployment on Sony IMX500 devices.
+    """Export YOLO model to IMX format for deployment on Sony IMX500 devices.
 
-    This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it
-    to IMX format compatible with Sony IMX500 edge devices. It supports both YOLOv8n and YOLO11n
-    models for detection and pose estimation tasks.
+    This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it to IMX format compatible
+    with Sony IMX500 edge devices. It supports both YOLOv8n and YOLO11n models for detection and pose estimation tasks.
 
     Args:
         model (torch.nn.Module): The YOLO model to export. Must be YOLOv8n or YOLO11n.
@@ -164,8 +223,8 @@ def torch2imx(
         iou (float): IoU threshold for NMS post-processing.
         max_det (int): Maximum number of detections to return.
         metadata (dict | None, optional): Metadata to embed in the ONNX model. Defaults to None.
-        gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization.
-            If False, uses standard Post Training Quantization. Defaults to False.
+        gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization. If False, uses standard Post
+            Training Quantization. Defaults to False.
         dataset (optional): Representative dataset for quantization calibration. Defaults to None.
         prefix (str, optional): Logging prefix string. Defaults to "".
 
@@ -175,13 +234,13 @@ def torch2imx(
     Raises:
         ValueError: If the model is not a supported YOLOv8n or YOLO11n variant.
 
-    Example:
+    Examples:
         >>> from ultralytics import YOLO
         >>> model = YOLO("yolo11n.pt")
-        >>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.45, max_det=300)
+        >>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.7, max_det=300)
 
-    Note:
-        - Requires model_compression_toolkit, onnx, edgemdt_tpc, and sony_custom_layers packages
+    Notes:
+        - Requires model_compression_toolkit, onnx, edgemdt_tpc, and edge-mdt-cl packages
         - Only supports YOLOv8n and YOLO11n models (detection and pose tasks)
         - Output includes quantized ONNX model, IMX binary, and labels.txt file
     """
@@ -197,33 +256,17 @@ def torch2imx(
         img = img / 255.0
         yield [img]
 
+    # NOTE: need tpc_version to be "4.0" for IMX500 Pose estimation models
     tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")
 
     bit_cfg = mct.core.BitWidthConfig()
-    if "C2PSA" in model.__str__():  # YOLO11
-        if model.task == "detect":
-            layer_names = ["sub", "mul_2", "add_14", "cat_21"]
-            weights_memory = 2585350.2439
-            n_layers = 238  # 238 layers for fused YOLO11n
-        elif model.task == "pose":
-            layer_names = ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"]
-            weights_memory = 2437771.67
-            n_layers = 257  # 257 layers for fused YOLO11n-pose
-    else:  # YOLOv8
-        if model.task == "detect":
-            layer_names = ["sub", "mul", "add_6", "cat_17"]
-            weights_memory = 2550540.8
-            n_layers = 168  # 168 layers for fused YOLOv8n
-        elif model.task == "pose":
-            layer_names = ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"]
-            weights_memory = 2482451.85
-            n_layers = 187  # 187 layers for fused YOLO11n-pose
+    mct_config = MCT_CONFIG["YOLO11" if "C2PSA" in model.__str__() else "YOLOv8"][model.task]
 
     # Check if the model has the expected number of layers
-    if len(list(model.modules())) != n_layers:
+    if len(list(model.modules())) not in mct_config["n_layers"]:
         raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")
 
-    for layer_name in layer_names:
+    for layer_name in mct_config["layer_names"]:
         bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)
 
     config = mct.core.CoreConfig(
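This hunk swaps the per-model if/elif ladder for a lookup into the `MCT_CONFIG` table added at the top of the file, and `n_layers` is now a set, so more than one fused layer count per variant passes the check. A hypothetical standalone rendering of the pattern (the table here is a trimmed copy; full values live in `MCT_CONFIG` above):

```python
CONFIG = {
    "YOLO11": {"detect": {"layer_names": ["sub", "mul_2", "add_14", "cat_19"], "n_layers": {238, 239}}},
    "YOLOv8": {"detect": {"layer_names": ["sub", "mul", "add_6", "cat_15"], "n_layers": {168, 169}}},
}


def select_config(model_str: str, task: str) -> dict:
    """Pick the quantization config by model family (only YOLO11 contains C2PSA blocks) and task."""
    return CONFIG["YOLO11" if "C2PSA" in model_str else "YOLOv8"][task]


cfg = select_config("... C2PSA ...", "detect")
assert 239 in cfg["n_layers"]
```

Keeping the constants in one data table rather than branching code makes adding a new family or task a one-entry change.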
@@ -232,7 +275,7 @@ def torch2imx(
         bit_width_config=bit_cfg,
     )
 
-    resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)
+    resource_utilization = mct.core.ResourceUtilization(weights_memory=mct_config["weights_memory"])
 
     quant_model = (
         mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
@@ -255,20 +298,23 @@ def torch2imx(
         )[0]
     )
 
-    quant_model = NMSWrapper(
-        model=quant_model,
-        score_threshold=conf or 0.001,
-        iou_threshold=iou,
-        max_detections=max_det,
-        task=model.task,
-    )
+    if model.task != "classify":
+        quant_model = NMSWrapper(
+            model=quant_model,
+            score_threshold=conf or 0.001,
+            iou_threshold=iou,
+            max_detections=max_det,
+            task=model.task,
+        )
 
     f = Path(str(file).replace(file.suffix, "_imx_model"))
     f.mkdir(exist_ok=True)
     onnx_model = f / Path(str(file.name).replace(file.suffix, "_imx.onnx"))  # js dir
-    mct.exporter.pytorch_export_model(
-        model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
-    )
+
+    with onnx_export_patch():
+        mct.exporter.pytorch_export_model(
+            model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
+        )
 
     model_onnx = onnx.load(onnx_model)  # load onnx model
     for k, v in metadata.items():
@@ -277,8 +323,16 @@ def torch2imx(
 
     onnx.save(model_onnx, onnx_model)
 
+    # Find imxconv-pt binary - check venv bin directory first, then PATH
+    bin_dir = Path(sys.executable).parent
+    imxconv = bin_dir / ("imxconv-pt.exe" if WINDOWS else "imxconv-pt")
+    if not imxconv.exists():
+        imxconv = which("imxconv-pt")  # fallback to PATH
+    if not imxconv:
+        raise FileNotFoundError("imxconv-pt not found. Install with: pip install imx500-converter[pt]")
+
     subprocess.run(
-        ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
+        [str(imxconv), "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
         check=True,
     )
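The closing hunk resolves `imxconv-pt` from the active environment's bin directory before falling back to PATH, which avoids picking up a stale system-wide install. The same pattern, generalized as a hypothetical helper (name and error message are ours):

```python
import sys
from pathlib import Path
from shutil import which


def find_console_script(name: str) -> str:
    """Resolve a pip-installed console script, preferring the active venv's bin/Scripts directory."""
    exe = name + (".exe" if sys.platform == "win32" else "")
    candidate = Path(sys.executable).parent / exe  # venv/bin on POSIX, venv/Scripts on Windows
    if candidate.exists():
        return str(candidate)
    found = which(name)  # fall back to PATH lookup
    if not found:
        raise FileNotFoundError(f"{name} not found; is the package that provides it installed?")
    return found
```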