ultralytics-opencv-headless 8.3.242__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/nn/autobackend.py (new file)
@@ -0,0 +1,958 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import ast
import json
import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
from typing import Any

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML, is_jetson
from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml, is_rockchip
from ultralytics.utils.downloads import attempt_download_asset, is_url
from ultralytics.utils.nms import non_max_suppression


def check_class_names(names: list | dict) -> dict[int, str]:
    """Check class names and convert to dict format if needed.

    Args:
        names (list | dict): Class names as list or dict format.

    Returns:
        (dict): Class names in dict format with integer keys and string values.

    Raises:
        KeyError: If class indices are invalid for the dataset size.
    """
    if isinstance(names, list):  # names is a list
        names = dict(enumerate(names))  # convert to dict
    if isinstance(names, dict):
        # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
        names = {int(k): str(v) for k, v in names.items()}
        n = len(names)
        if max(names.keys()) >= n:
            raise KeyError(
                f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
                f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML."
            )
        if isinstance(names[0], str) and names[0].startswith("n0"):  # imagenet class codes, i.e. 'n01440764'
            names_map = YAML.load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
            names = {k: names_map[v] for k, v in names.items()}
    return names

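# --- Editor's note: a minimal usage sketch for check_class_names (illustrative, not part of the package) ---
# Lists are enumerated into dicts, string keys become ints, and non-string values become strings:
#
#   >>> check_class_names(["person", "bicycle"])
#   {0: 'person', 1: 'bicycle'}
#   >>> check_class_names({"0": "person", "1": True})
#   {0: 'person', 1: 'True'}
#
# Out-of-range indices (e.g. {0: "person", 2: "car"} for a 2-class dataset) raise KeyError.
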
def default_class_names(data: str | Path | None = None) -> dict[int, str]:
    """Apply default class names to an input YAML file or return numerical class names.

    Args:
        data (str | Path, optional): Path to YAML file containing class names.

    Returns:
        (dict): Dictionary mapping class indices to class names.
    """
    if data:
        try:
            return YAML.load(check_yaml(data))["names"]
        except Exception:
            pass
    return {i: f"class{i}" for i in range(999)}  # return default if above errors

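# --- Editor's note: illustrative, not part of the package ---
# Without a readable data YAML, default_class_names() falls back to 999 generic labels;
# AutoBackend later uses len(names) == 999 as a "names not defined" sentinel in forward():
#
#   >>> names = default_class_names()
#   >>> len(names), names[0]
#   (999, 'class0')
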
class AutoBackend(nn.Module):
    """Handle dynamic backend selection for running inference using Ultralytics YOLO models.

    The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
    range of formats, each with specific naming conventions as outlined below:

    Supported Formats and Naming Conventions:
        | Format                | File Suffix       |
        | --------------------- | ----------------- |
        | PyTorch               | *.pt              |
        | TorchScript           | *.torchscript     |
        | ONNX Runtime          | *.onnx            |
        | ONNX OpenCV DNN       | *.onnx (dnn=True) |
        | OpenVINO              | *openvino_model/  |
        | CoreML                | *.mlpackage       |
        | TensorRT              | *.engine          |
        | TensorFlow SavedModel | *_saved_model/    |
        | TensorFlow GraphDef   | *.pb              |
        | TensorFlow Lite       | *.tflite          |
        | TensorFlow Edge TPU   | *_edgetpu.tflite  |
        | PaddlePaddle          | *_paddle_model/   |
        | MNN                   | *.mnn             |
        | NCNN                  | *_ncnn_model/     |
        | IMX                   | *_imx_model/      |
        | RKNN                  | *_rknn_model/     |
        | Triton Inference      | triton://model    |
        | ExecuTorch            | *.pte             |
        | Axelera               | *_axelera_model/  |

    Attributes:
        model (torch.nn.Module): The loaded YOLO model.
        device (torch.device): The device (CPU or GPU) on which the model is loaded.
        task (str): The type of task the model performs (detect, segment, classify, pose).
        names (dict): A dictionary of class names that the model can detect.
        stride (int): The model stride, typically 32 for YOLO models.
        fp16 (bool): Whether the model uses half-precision (FP16) inference.
        nhwc (bool): Whether the model expects NHWC input format instead of NCHW.
        pt (bool): Whether the model is a PyTorch model.
        jit (bool): Whether the model is a TorchScript model.
        onnx (bool): Whether the model is an ONNX model.
        xml (bool): Whether the model is an OpenVINO model.
        engine (bool): Whether the model is a TensorRT engine.
        coreml (bool): Whether the model is a CoreML model.
        saved_model (bool): Whether the model is a TensorFlow SavedModel.
        pb (bool): Whether the model is a TensorFlow GraphDef.
        tflite (bool): Whether the model is a TensorFlow Lite model.
        edgetpu (bool): Whether the model is a TensorFlow Edge TPU model.
        tfjs (bool): Whether the model is a TensorFlow.js model.
        paddle (bool): Whether the model is a PaddlePaddle model.
        mnn (bool): Whether the model is an MNN model.
        ncnn (bool): Whether the model is an NCNN model.
        imx (bool): Whether the model is an IMX model.
        rknn (bool): Whether the model is an RKNN model.
        triton (bool): Whether the model is a Triton Inference Server model.
        pte (bool): Whether the model is a PyTorch ExecuTorch model.
        axelera (bool): Whether the model is an Axelera model.

    Methods:
        forward: Run inference on an input image.
        from_numpy: Convert NumPy arrays to tensors on the model device.
        warmup: Warm up the model with a dummy input.
        _model_type: Determine the model type from file path.

    Examples:
        >>> model = AutoBackend(model="yolo11n.pt", device="cuda")
        >>> results = model(img)
    """

    @torch.no_grad()
    def __init__(
        self,
        model: str | torch.nn.Module = "yolo11n.pt",
        device: torch.device = torch.device("cpu"),
        dnn: bool = False,
        data: str | Path | None = None,
        fp16: bool = False,
        fuse: bool = True,
        verbose: bool = True,
    ):
        """Initialize the AutoBackend for inference.

        Args:
            model (str | torch.nn.Module): Path to the model weights file or a module instance.
            device (torch.device): Device to run the model on.
            dnn (bool): Use OpenCV DNN module for ONNX inference.
            data (str | Path, optional): Path to the additional data.yaml file containing class names.
            fp16 (bool): Enable half-precision inference. Supported only on specific backends.
            fuse (bool): Fuse Conv2D + BatchNorm layers for optimization.
            verbose (bool): Enable verbose logging.
        """
        super().__init__()
        nn_module = isinstance(model, torch.nn.Module)
        (
            pt,
            jit,
            onnx,
            xml,
            engine,
            coreml,
            saved_model,
            pb,
            tflite,
            edgetpu,
            tfjs,
            paddle,
            mnn,
            ncnn,
            imx,
            rknn,
            pte,
            axelera,
            triton,
        ) = self._model_type("" if nn_module else model)
        fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
        nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch BCHW)
        stride, ch = 32, 3  # default stride and channels
        end2end, dynamic = False, False
        metadata, task = None, None

        # Set device
        cuda = isinstance(device, torch.device) and torch.cuda.is_available() and device.type != "cpu"  # use CUDA
        if cuda and not any([nn_module, pt, jit, engine, onnx, paddle]):  # GPU dataloader formats
            device = torch.device("cpu")
            cuda = False

        # Download if not local
        w = attempt_download_asset(model) if pt else model  # weights path

        # PyTorch (in-memory or file)
        if nn_module or pt:
            if nn_module:
                pt = True
                if fuse:
                    if IS_JETSON and is_jetson(jetpack=5):
                        # Jetson Jetpack5 requires device before fuse https://github.com/ultralytics/ultralytics/pull/21028
                        model = model.to(device)
                    model = model.fuse(verbose=verbose)
                model = model.to(device)
            else:  # pt file
                from ultralytics.nn.tasks import load_checkpoint

                model, _ = load_checkpoint(model, device=device, fuse=fuse)  # load model, ckpt

            # Common PyTorch model processing
            if hasattr(model, "kpt_shape"):
                kpt_shape = model.kpt_shape  # pose-only
            stride = max(int(model.stride.max()), 32)  # model stride
            names = model.module.names if hasattr(model, "module") else model.names  # get class names
            model.half() if fp16 else model.float()
            ch = model.yaml.get("channels", 3)
            for p in model.parameters():
                p.requires_grad = False
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()

        # TorchScript
        elif jit:
            import torchvision  # noqa - https://github.com/ultralytics/ultralytics/pull/19747

            LOGGER.info(f"Loading {w} for TorchScript inference...")
            extra_files = {"config.txt": ""}  # model metadata
            model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
            model.half() if fp16 else model.float()
            if extra_files["config.txt"]:  # load metadata dict
                metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))

        # ONNX OpenCV DNN
        elif dnn:
            LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
            check_requirements("opencv-python>=4.5.4")
            net = cv2.dnn.readNetFromONNX(w)

        # ONNX Runtime and IMX
        elif onnx or imx:
            LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
            check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
            import onnxruntime

            # Select execution provider: CUDA > CoreML (mps) > CPU
            available = onnxruntime.get_available_providers()
            if cuda and "CUDAExecutionProvider" in available:
                providers = [("CUDAExecutionProvider", {"device_id": device.index}), "CPUExecutionProvider"]
            elif device.type == "mps" and "CoreMLExecutionProvider" in available:
                providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
            else:
                providers = ["CPUExecutionProvider"]
                if cuda:
                    LOGGER.warning("CUDA requested but CUDAExecutionProvider not available. Using CPU...")
                    device, cuda = torch.device("cpu"), False
            LOGGER.info(
                f"Using ONNX Runtime {onnxruntime.__version__} with {providers[0] if isinstance(providers[0], str) else providers[0][0]}"
            )
            if onnx:
                session = onnxruntime.InferenceSession(w, providers=providers)
            else:
                check_requirements(("model-compression-toolkit>=2.4.1", "edge-mdt-cl<1.1.0", "onnxruntime-extensions"))
                w = next(Path(w).glob("*.onnx"))
                LOGGER.info(f"Loading {w} for ONNX IMX inference...")
                import mct_quantizers as mctq
                from edgemdt_cl.pytorch.nms import nms_ort  # noqa - register custom NMS ops

                session_options = mctq.get_ort_session_options()
                session_options.enable_mem_reuse = False  # fix the shape mismatch from onnxruntime
                session = onnxruntime.InferenceSession(w, session_options, providers=["CPUExecutionProvider"])

            output_names = [x.name for x in session.get_outputs()]
            metadata = session.get_modelmeta().custom_metadata_map
            dynamic = isinstance(session.get_outputs()[0].shape[0], str)
            fp16 = "float16" in session.get_inputs()[0].type

            # Setup IO binding for optimized inference (CUDA only, not supported for CoreML)
            use_io_binding = not dynamic and cuda
            if use_io_binding:
                io = session.io_binding()
                bindings = []
                for output in session.get_outputs():
                    out_fp16 = "float16" in output.type
                    y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)
                    io.bind_output(
                        name=output.name,
                        device_type=device.type,
                        device_id=device.index if cuda else 0,
                        element_type=np.float16 if out_fp16 else np.float32,
                        shape=tuple(y_tensor.shape),
                        buffer_ptr=y_tensor.data_ptr(),
                    )
                    bindings.append(y_tensor)

        # OpenVINO
        elif xml:
            LOGGER.info(f"Loading {w} for OpenVINO inference...")
            check_requirements("openvino>=2024.0.0")
            import openvino as ov

            core = ov.Core()
            device_name = "AUTO"
            if isinstance(device, str) and device.startswith("intel"):
                device_name = device.split(":")[1].upper()  # Intel OpenVINO device
                device = torch.device("cpu")
                if device_name not in core.available_devices:
                    LOGGER.warning(f"OpenVINO device '{device_name}' not available. Using 'AUTO' instead.")
                    device_name = "AUTO"
            w = Path(w)
            if not w.is_file():  # if not *.xml
                w = next(w.glob("*.xml"))  # get *.xml file from *_openvino_model dir
            ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
            if ov_model.get_parameters()[0].get_layout().empty:
                ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW"))

            metadata = w.parent / "metadata.yaml"
            if metadata.exists():
                metadata = YAML.load(metadata)
                batch = metadata["batch"]
                dynamic = metadata.get("args", {}).get("dynamic", dynamic)
            # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
            inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 and dynamic else "LATENCY"
            ov_compiled_model = core.compile_model(
                ov_model,
                device_name=device_name,
                config={"PERFORMANCE_HINT": inference_mode},
            )
            LOGGER.info(
                f"Using OpenVINO {inference_mode} mode for batch={batch} inference on {', '.join(ov_compiled_model.get_property('EXECUTION_DEVICES'))}..."
            )
            input_name = ov_compiled_model.input().get_any_name()

        # TensorRT
        elif engine:
            LOGGER.info(f"Loading {w} for TensorRT inference...")

            if IS_JETSON and check_version(PYTHON_VERSION, "<=3.8.10"):
                # fix error: `np.bool` was a deprecated alias for the builtin `bool` for JetPack 4 and JetPack 5 with Python <= 3.8.10
                check_requirements("numpy==1.23.5")

            try:  # https://developer.nvidia.com/nvidia-tensorrt-download
                import tensorrt as trt
            except ImportError:
                if LINUX:
                    check_requirements("tensorrt>7.0.0,!=10.1.0")
                import tensorrt as trt
            check_version(trt.__version__, ">=7.0.0", hard=True)
            check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
            if device.type == "cpu":
                device = torch.device("cuda:0")
            Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
            logger = trt.Logger(trt.Logger.INFO)
            # Read file
            with open(w, "rb") as f, trt.Runtime(logger) as runtime:
                try:
                    meta_len = int.from_bytes(f.read(4), byteorder="little")  # read metadata length
                    metadata = json.loads(f.read(meta_len).decode("utf-8"))  # read metadata
                    dla = metadata.get("dla", None)
                    if dla is not None:
                        runtime.DLA_core = int(dla)
                except UnicodeDecodeError:
                    f.seek(0)  # engine file may lack embedded Ultralytics metadata
                model = runtime.deserialize_cuda_engine(f.read())  # read engine

            # Model context
            try:
                context = model.create_execution_context()
            except Exception as e:  # model is None
                LOGGER.error(f"TensorRT model exported with a different version than {trt.__version__}\n")
                raise e

            bindings = OrderedDict()
            output_names = []
            fp16 = False  # default updated below
            dynamic = False
            is_trt10 = not hasattr(model, "num_bindings")
            num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
            for i in num:
                # Get tensor info using TRT10+ or legacy API
                if is_trt10:
                    name = model.get_tensor_name(i)
                    dtype = trt.nptype(model.get_tensor_dtype(name))
                    is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
                    shape = tuple(model.get_tensor_shape(name))
                    profile_shape = tuple(model.get_tensor_profile_shape(name, 0)[2]) if is_input else None
                else:
                    name = model.get_binding_name(i)
                    dtype = trt.nptype(model.get_binding_dtype(i))
                    is_input = model.binding_is_input(i)
                    shape = tuple(model.get_binding_shape(i))
                    profile_shape = tuple(model.get_profile_shape(0, i)[1]) if is_input else None

                # Process input/output tensors
                if is_input:
                    if -1 in shape:
                        dynamic = True
                        if is_trt10:
                            context.set_input_shape(name, profile_shape)
                        else:
                            context.set_binding_shape(i, profile_shape)
                    if dtype == np.float16:
                        fp16 = True
                else:
                    output_names.append(name)
                    shape = tuple(context.get_tensor_shape(name)) if is_trt10 else tuple(context.get_binding_shape(i))
                im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())

        # CoreML
        elif coreml:
            check_requirements(
                ["coremltools>=9.0", "numpy>=1.14.5,<=2.3.5"]
            )  # latest numpy 2.4.0rc1 breaks coremltools exports
            LOGGER.info(f"Loading {w} for CoreML inference...")
            import coremltools as ct

            model = ct.models.MLModel(w)
            dynamic = model.get_spec().description.input[0].type.HasField("multiArrayType")
            metadata = dict(model.user_defined_metadata)

        # TF SavedModel
        elif saved_model:
            LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
            import tensorflow as tf

            model = tf.saved_model.load(w)
            metadata = Path(w) / "metadata.yaml"

        # TF GraphDef
        elif pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
            LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
            import tensorflow as tf

            from ultralytics.utils.export.tensorflow import gd_outputs

            def wrap_frozen_graph(gd, inputs, outputs):
                """Wrap frozen graphs for deployment."""
                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
                ge = x.graph.as_graph_element
                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

            gd = tf.Graph().as_graph_def()  # TF GraphDef
            with open(w, "rb") as f:
                gd.ParseFromString(f.read())
            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
            try:  # find metadata in SavedModel alongside GraphDef
                metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
            except StopIteration:
                pass

        # TFLite or TFLite Edge TPU
        elif tflite or edgetpu:  # https://ai.google.dev/edge/litert/microcontrollers/python
            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                from tflite_runtime.interpreter import Interpreter, load_delegate
            except ImportError:
                import tensorflow as tf

                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
            if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
                device = device[3:] if str(device).startswith("tpu") else ":0"
                LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...")
                delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
                    platform.system()
                ]
                interpreter = Interpreter(
                    model_path=w,
                    experimental_delegates=[load_delegate(delegate, options={"device": device})],
                )
                device = "cpu"  # Required, otherwise PyTorch will try to use the wrong device
            else:  # TFLite
                LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
                interpreter = Interpreter(model_path=w)  # load TFLite model
            interpreter.allocate_tensors()  # allocate
            input_details = interpreter.get_input_details()  # inputs
            output_details = interpreter.get_output_details()  # outputs
            # Load metadata
            try:
                with zipfile.ZipFile(w, "r") as zf:
                    name = zf.namelist()[0]
                    contents = zf.read(name).decode("utf-8")
                    if name == "metadata.json":  # Custom Ultralytics metadata dict for Python>=3.12
                        metadata = json.loads(contents)
                    else:
                        metadata = ast.literal_eval(contents)  # Default tflite-support metadata for Python<=3.11
            except (zipfile.BadZipFile, SyntaxError, ValueError, json.JSONDecodeError):
                pass

        # TF.js
        elif tfjs:
            raise NotImplementedError("Ultralytics TF.js inference is not currently supported.")

        # PaddlePaddle
        elif paddle:
            LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
            check_requirements(
                "paddlepaddle-gpu"
                if torch.cuda.is_available()
                else "paddlepaddle==3.0.0"  # pin 3.0.0 for ARM64
                if ARM64
                else "paddlepaddle>=3.0.0"
            )
            import paddle.inference as pdi

            w = Path(w)
            model_file, params_file = None, None
            if w.is_dir():
                model_file = next(w.rglob("*.json"), None)
                params_file = next(w.rglob("*.pdiparams"), None)
            elif w.suffix == ".pdiparams":
                model_file = w.with_name("model.json")
                params_file = w

            if not (model_file and params_file and model_file.is_file() and params_file.is_file()):
                raise FileNotFoundError(f"Paddle model not found in {w}. Both .json and .pdiparams files are required.")

            config = pdi.Config(str(model_file), str(params_file))
            if cuda:
                config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
            predictor = pdi.create_predictor(config)
            input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
            output_names = predictor.get_output_names()
            metadata = w / "metadata.yaml"

        # MNN
        elif mnn:
            LOGGER.info(f"Loading {w} for MNN inference...")
            check_requirements("MNN")  # requires MNN
            import os

            import MNN

            config = {"precision": "low", "backend": "CPU", "numThread": (os.cpu_count() + 1) // 2}
            rt = MNN.nn.create_runtime_manager((config,))
            net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True)

            def torch_to_mnn(x):
                return MNN.expr.const(x.data_ptr(), x.shape)

            metadata = json.loads(net.get_info()["bizCode"])

        # NCNN
        elif ncnn:
            LOGGER.info(f"Loading {w} for NCNN inference...")
            check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn", cmds="--no-deps")
            import ncnn as pyncnn

            net = pyncnn.Net()
            net.opt.use_vulkan_compute = cuda
            w = Path(w)
            if not w.is_file():  # if not *.param
                w = next(w.glob("*.param"))  # get *.param file from *_ncnn_model dir
            net.load_param(str(w))
            net.load_model(str(w.with_suffix(".bin")))
            metadata = w.parent / "metadata.yaml"

        # NVIDIA Triton Inference Server
        elif triton:
            check_requirements("tritonclient[all]")
            from ultralytics.utils.triton import TritonRemoteModel

            model = TritonRemoteModel(w)
            metadata = model.metadata

        # RKNN
        elif rknn:
            if not is_rockchip():
                raise OSError("RKNN inference is only supported on Rockchip devices.")
            LOGGER.info(f"Loading {w} for RKNN inference...")
            check_requirements("rknn-toolkit-lite2")
            from rknnlite.api import RKNNLite

            w = Path(w)
            if not w.is_file():  # if not *.rknn
                w = next(w.rglob("*.rknn"))  # get *.rknn file from *_rknn_model dir
            rknn_model = RKNNLite()
            rknn_model.load_rknn(str(w))
            rknn_model.init_runtime()
            metadata = w.parent / "metadata.yaml"

        # Axelera
        elif axelera:
            import os

            if not os.environ.get("AXELERA_RUNTIME_DIR"):
                LOGGER.warning(
                    "Axelera runtime environment is not activated."
                    "\nPlease run: source /opt/axelera/sdk/latest/axelera_activate.sh"
                    "\n\nIf this fails, verify driver installation: https://docs.ultralytics.com/integrations/axelera/#axelera-driver-installation"
                )
            try:
                from axelera.runtime import op
            except ImportError:
                check_requirements(
                    "axelera_runtime2==0.1.2",
                    cmds="--extra-index-url https://software.axelera.ai/artifactory/axelera-runtime-pypi",
                )
                from axelera.runtime import op

            w = Path(w)
            if (found := next(w.rglob("*.axm"), None)) is None:
                raise FileNotFoundError(f"No .axm file found in: {w}")

            ax_model = op.load(str(found))
            metadata = found.parent / "metadata.yaml"

        # ExecuTorch
        elif pte:
            LOGGER.info(f"Loading {w} for ExecuTorch inference...")
            # TorchAO release compatibility table bug https://github.com/pytorch/ao/issues/2919
            check_requirements("setuptools<71.0.0")  # Setuptools bug: https://github.com/pypa/setuptools/issues/4483
            check_requirements(("executorch==1.0.1", "flatbuffers"))
            from executorch.runtime import Runtime

            w = Path(w)
            if w.is_dir():
                model_file = next(w.rglob("*.pte"))
                metadata = w / "metadata.yaml"
            else:
                model_file = w
                metadata = w.parent / "metadata.yaml"

            program = Runtime.get().load_program(str(model_file))
            model = program.load_method("forward")

        # Any other format (unsupported)
        else:
            from ultralytics.engine.exporter import export_formats

            raise TypeError(
                f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n"
                f"See https://docs.ultralytics.com/modes/predict for help."
            )

        # Load external metadata YAML
        if isinstance(metadata, (str, Path)) and Path(metadata).exists():
            metadata = YAML.load(metadata)
        if metadata and isinstance(metadata, dict):
            for k, v in metadata.items():
                if k in {"stride", "batch", "channels"}:
                    metadata[k] = int(v)
                elif k in {"imgsz", "names", "kpt_shape", "kpt_names", "args"} and isinstance(v, str):
                    metadata[k] = ast.literal_eval(v)
            stride = metadata["stride"]
            task = metadata["task"]
            batch = metadata["batch"]
            imgsz = metadata["imgsz"]
            names = metadata["names"]
            kpt_shape = metadata.get("kpt_shape")
            kpt_names = metadata.get("kpt_names")
            end2end = metadata.get("args", {}).get("nms", False)
            dynamic = metadata.get("args", {}).get("dynamic", dynamic)
            ch = metadata.get("channels", 3)
        elif not (pt or triton or nn_module):
            LOGGER.warning(f"Metadata not found for 'model={w}'")

        # Check names
        if "names" not in locals():  # names missing
            names = default_class_names(data)
        names = check_class_names(names)

        self.__dict__.update(locals())  # assign all variables to self

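# --- Editor's note: a minimal construction sketch (illustrative, not part of the package) ---
# Backend dispatch is driven purely by the model path, so the same call site covers every
# supported format; "yolo11n.onnx" below is a hypothetical export produced beforehand with
# `yolo export model=yolo11n.pt format=onnx`:
#
#   import torch
#   from ultralytics.nn.autobackend import AutoBackend
#
#   backend = AutoBackend("yolo11n.onnx", device=torch.device("cpu"))
#   print(backend.stride, backend.task, backend.names[0])  # read back from export metadata
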
    def forward(
        self,
        im: torch.Tensor,
        augment: bool = False,
        visualize: bool = False,
        embed: list | None = None,
        **kwargs: Any,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run inference on an AutoBackend model.

        Args:
            im (torch.Tensor): The image tensor to perform inference on.
            augment (bool): Whether to perform data augmentation during inference.
            visualize (bool): Whether to visualize the output predictions.
            embed (list, optional): A list of feature vectors/embeddings to return.
            **kwargs (Any): Additional keyword arguments for model configuration.

        Returns:
            (torch.Tensor | list[torch.Tensor]): The raw output tensor(s) from the model.
        """
        _b, _ch, h, w = im.shape  # batch, channel, height, width
        if self.fp16 and im.dtype != torch.float16:
            im = im.half()  # to FP16
        if self.nhwc:
            im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

        # PyTorch
        if self.pt or self.nn_module:
            y = self.model(im, augment=augment, visualize=visualize, embed=embed, **kwargs)

        # TorchScript
        elif self.jit:
            y = self.model(im)

        # ONNX OpenCV DNN
        elif self.dnn:
            im = im.cpu().numpy()  # torch to numpy
            self.net.setInput(im)
            y = self.net.forward()

        # ONNX Runtime
        elif self.onnx or self.imx:
            if self.use_io_binding:
                if not self.cuda:
                    im = im.cpu()
                self.io.bind_input(
                    name="images",
                    device_type=im.device.type,
                    device_id=im.device.index if im.device.type == "cuda" else 0,
                    element_type=np.float16 if self.fp16 else np.float32,
                    shape=tuple(im.shape),
                    buffer_ptr=im.data_ptr(),
                )
                self.session.run_with_iobinding(self.io)
                y = self.bindings
            else:
                im = im.cpu().numpy()  # torch to numpy
                y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
            if self.imx:
                if self.task == "detect":
                    # boxes, conf, cls
                    y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)
                elif self.task == "pose":
                    # boxes, conf, kpts
                    y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1, dtype=y[0].dtype)
                elif self.task == "segment":
                    y = (
                        np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1, dtype=y[0].dtype),
                        y[4],
                    )

        # OpenVINO
        elif self.xml:
            im = im.cpu().numpy()  # FP32

            if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}:  # optimized for larger batch-sizes
                n = im.shape[0]  # number of images in batch
                results = [None] * n  # preallocate list with None to match the number of images

                def callback(request, userdata):
                    """Place result in preallocated list using userdata index."""
                    results[userdata] = request.results

                # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
                async_queue = self.ov.AsyncInferQueue(self.ov_compiled_model)
                async_queue.set_callback(callback)
                for i in range(n):
                    # Start async inference with userdata=i to specify the position in results list
                    async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
                async_queue.wait_all()  # wait for all inference requests to complete
                y = [list(r.values()) for r in results]
                y = [np.concatenate(x) for x in zip(*y)]
            else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
                y = list(self.ov_compiled_model(im).values())

        # TensorRT
        elif self.engine:
            if self.dynamic and im.shape != self.bindings["images"].shape:
                if self.is_trt10:
                    self.context.set_input_shape("images", im.shape)
                    self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
                    for name in self.output_names:
                        self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
                else:
                    i = self.model.get_binding_index("images")
                    self.context.set_binding_shape(i, im.shape)
                    self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
                    for name in self.output_names:
                        i = self.model.get_binding_index(name)
                        self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))

            s = self.bindings["images"].shape
            assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs["images"] = int(im.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = [self.bindings[x].data for x in sorted(self.output_names)]

        # CoreML
        elif self.coreml:
            im = im.cpu().numpy()
            if self.dynamic:
                im = im.transpose(0, 3, 1, 2)
            else:
                im = Image.fromarray((im[0] * 255).astype("uint8"))
                # im = im.resize((192, 320), Image.BILINEAR)
            y = self.model.predict({"image": im})  # coordinates are xywh normalized
            if "confidence" in y:  # NMS included
                from ultralytics.utils.ops import xywh2xyxy

                box = xywh2xyxy(y["coordinates"] * [[w, h, w, h]])  # xyxy pixels
                cls = y["confidence"].argmax(1, keepdims=True)
                y = np.concatenate((box, np.take_along_axis(y["confidence"], cls, axis=1), cls), 1)[None]
            else:
                y = list(y.values())
                if len(y) == 2 and len(y[1].shape) != 4:  # segmentation model
                    y = list(reversed(y))  # reversed for segmentation models (pred, proto)

        # PaddlePaddle
        elif self.paddle:
            im = im.cpu().numpy().astype(np.float32)
            self.input_handle.copy_from_cpu(im)
            self.predictor.run()
            y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]

        # MNN
        elif self.mnn:
            input_var = self.torch_to_mnn(im)
            output_var = self.net.onForward([input_var])
            y = [x.read() for x in output_var]

        # NCNN
        elif self.ncnn:
            mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
            with self.net.create_extractor() as ex:
                ex.input(self.net.input_names()[0], mat_in)
                # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130
                y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())]

        # NVIDIA Triton Inference Server
        elif self.triton:
            im = im.cpu().numpy()  # torch to numpy
            y = self.model(im)

        # RKNN
        elif self.rknn:
            im = (im.cpu().numpy() * 255).astype("uint8")
            im = im if isinstance(im, (list, tuple)) else [im]
            y = self.rknn_model.inference(inputs=im)

        # Axelera
        elif self.axelera:
            y = self.ax_model(im.cpu())

        # ExecuTorch
        elif self.pte:
            y = self.model.execute([im])

        # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
        else:
            im = im.cpu().numpy()
            if self.saved_model:  # SavedModel
                y = self.model.serving_default(im)
                if not isinstance(y, list):
                    y = [y]
            elif self.pb:  # GraphDef
                y = self.frozen_func(x=self.tf.constant(im))
            else:  # Lite or Edge TPU
                details = self.input_details[0]
                is_int = details["dtype"] in {np.int8, np.int16}  # is TFLite quantized int8 or int16 model
                if is_int:
                    scale, zero_point = details["quantization"]
                    im = (im / scale + zero_point).astype(details["dtype"])  # de-scale
                self.interpreter.set_tensor(details["index"], im)
                self.interpreter.invoke()
                y = []
                for output in self.output_details:
                    x = self.interpreter.get_tensor(output["index"])
                    if is_int:
                        scale, zero_point = output["quantization"]
                        x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                    if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
                        # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                        # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
                        if x.shape[-1] == 6 or self.end2end:  # end-to-end model
                            x[:, :, [0, 2]] *= w
                            x[:, :, [1, 3]] *= h
                            if self.task == "pose":
                                x[:, :, 6::3] *= w
                                x[:, :, 7::3] *= h
                        else:
                            x[:, [0, 2]] *= w
                            x[:, [1, 3]] *= h
                            if self.task == "pose":
                                x[:, 5::3] *= w
                                x[:, 6::3] *= h
                    y.append(x)
            # TF segment fixes: export is reversed vs ONNX export and protos are transposed
            if len(y) == 2:  # segment with (det, proto) output order reversed
                if len(y[1].shape) != 4:
                    y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
                if y[1].shape[-1] == 6:  # end-to-end model
                    y = [y[1]]
                else:
                    y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
            y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

        if isinstance(y, (list, tuple)):
            if len(self.names) == 999 and (self.task == "segment" or len(y) == 2):  # segments and names not defined
                nc = y[0].shape[1] - y[1].shape[1] - 4  # y = (1, 32, 160, 160), (1, 116, 8400)
                self.names = {i: f"class{i}" for i in range(nc)}
            return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
        else:
            return self.from_numpy(y)

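# --- Editor's note: a minimal inference sketch (illustrative, not part of the package) ---
# forward() expects a BCHW float tensor scaled to [0, 1]; the FP16 cast and the BCHW -> BHWC
# permute for NHWC backends (CoreML, TF*, RKNN) happen internally. For a quantized TFLite
# model, the de-scale math above means that with, e.g., scale=0.004 and zero_point=-128 an
# input value of 0.5 becomes round(0.5 / 0.004) - 128 = -3 as int8.
#
#   im = torch.rand(1, 3, 640, 640)  # hypothetical letterboxed batch
#   backend = AutoBackend("yolo11n.pt")
#   y = backend.forward(im)  # raw head output, e.g. shape (1, 84, 8400) for 80-class detect
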
    def from_numpy(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Convert a NumPy array to a torch tensor on the model device.

        Args:
            x (np.ndarray | torch.Tensor): Input array or tensor.

        Returns:
            (torch.Tensor): Tensor on `self.device`.
        """
        return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x

    def warmup(self, imgsz: tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:
        """Warm up the model by running one forward pass with a dummy input.

        Args:
            imgsz (tuple[int, int, int, int]): Dummy input shape in (batch, channels, height, width) format.
        """
        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
        if any(warmup_types) and (self.device.type != "cpu" or self.triton):
            im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
            for _ in range(2 if self.jit else 1):
                self.forward(im)  # warmup model
            warmup_boxes = torch.rand(1, 84, 16, device=self.device)  # 16 boxes works best empirically
            warmup_boxes[:, :4] *= imgsz[-1]
            non_max_suppression(warmup_boxes)  # warmup NMS

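# --- Editor's note: illustrative, not part of the package ---
# warmup() is skipped on CPU for non-Triton backends (the `device.type != "cpu"` guard above)
# and runs the dummy pass twice for TorchScript, whose JIT specializes on the first calls,
# so the first real frame is not charged the one-time allocation/compile cost:
#
#   backend = AutoBackend("yolo11n.engine", device=torch.device("cuda:0"))  # hypothetical engine
#   backend.warmup(imgsz=(1, 3, 640, 640))
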
    @staticmethod
    def _model_type(p: str = "path/to/model.pt") -> list[bool]:
        """Take a path to a model file and return the model type.

        Args:
            p (str): Path to the model file.

        Returns:
            (list[bool]): List of booleans indicating the model type.

        Examples:
            >>> types = AutoBackend._model_type("path/to/model.onnx")
            >>> assert types[2]  # onnx
        """
        from ultralytics.engine.exporter import export_formats

        sf = export_formats()["Suffix"]  # export suffixes
        if not is_url(p) and not isinstance(p, str):
            check_suffix(p, sf)  # checks
        name = Path(p).name
        types = [s in name for s in sf]
        types[5] |= name.endswith(".mlmodel")  # retain support for older Apple CoreML *.mlmodel formats
        types[8] &= not types[9]  # tflite &= not edgetpu
        if any(types):
            triton = False
        else:
            from urllib.parse import urlsplit

            url = urlsplit(p)
            triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}

        return [*types, triton]
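
# --- Editor's note: a dispatch sketch for _model_type (illustrative, not part of the package) ---
# The returned booleans are positional and match the tuple unpacked in __init__, with the
# Triton flag appended last; a bare URL with an http/grpc scheme is treated as Triton:
#
#   pt, jit, onnx, *rest, triton = AutoBackend._model_type("yolo11n.onnx")
#   assert onnx and not (pt or triton)
#   *_, triton = AutoBackend._model_type("http://localhost:8000/yolo11n")
#   assert triton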