dgenerate_ultralytics_headless-8.3.134-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
ultralytics/nn/autobackend.py
@@ -0,0 +1,842 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import ast
+ import json
+ import platform
+ import zipfile
+ from collections import OrderedDict, namedtuple
+ from pathlib import Path
+ from typing import List, Optional, Union
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+
+ from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML
+ from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml, is_rockchip
+ from ultralytics.utils.downloads import attempt_download_asset, is_url
+
+
+ def check_class_names(names):
+     """Check class names and convert to dict format if needed."""
+     if isinstance(names, list):  # names is a list
+         names = dict(enumerate(names))  # convert to dict
+     if isinstance(names, dict):
+         # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
+         names = {int(k): str(v) for k, v in names.items()}
+         n = len(names)
+         if max(names.keys()) >= n:
+             raise KeyError(
+                 f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
+                 f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML."
+             )
+         if isinstance(names[0], str) and names[0].startswith("n0"):  # imagenet class codes, i.e. 'n01440764'
+             names_map = YAML.load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
+             names = {k: names_map[v] for k, v in names.items()}
+     return names
+
+
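As a quick illustration of the normalization above, a minimal sketch (inputs are illustrative; the function is importable from this module):

```python
from ultralytics.nn.autobackend import check_class_names

print(check_class_names(["person", "bicycle", "car"]))  # {0: 'person', 1: 'bicycle', 2: 'car'}
print(check_class_names({"0": "person", 1: True}))  # {0: 'person', 1: 'True'}

try:
    check_class_names({0: "person", 5: "car"})  # 2 classes, but index 5 is out of range
except KeyError as e:
    print(e)  # "2-class dataset requires class indices 0-1 ..."
```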
+ def default_class_names(data=None):
+     """Applies default class names to an input YAML file or returns numerical class names."""
+     if data:
+         try:
+             return YAML.load(check_yaml(data))["names"]
+         except Exception:
+             pass
+     return {i: f"class{i}" for i in range(999)}  # return default if above errors
+
+
+ class AutoBackend(nn.Module):
+     """
+     Handles dynamic backend selection for running inference using Ultralytics YOLO models.
+
+     The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
+     range of formats, each with specific naming conventions as outlined below:
+
+     Supported Formats and Naming Conventions:
+         | Format                | File Suffix       |
+         | --------------------- | ----------------- |
+         | PyTorch               | *.pt              |
+         | TorchScript           | *.torchscript     |
+         | ONNX Runtime          | *.onnx            |
+         | ONNX OpenCV DNN       | *.onnx (dnn=True) |
+         | OpenVINO              | *openvino_model/  |
+         | CoreML                | *.mlpackage       |
+         | TensorRT              | *.engine          |
+         | TensorFlow SavedModel | *_saved_model/    |
+         | TensorFlow GraphDef   | *.pb              |
+         | TensorFlow Lite       | *.tflite          |
+         | TensorFlow Edge TPU   | *_edgetpu.tflite  |
+         | PaddlePaddle          | *_paddle_model/   |
+         | MNN                   | *.mnn             |
+         | NCNN                  | *_ncnn_model/     |
+         | IMX                   | *_imx_model/      |
+         | RKNN                  | *_rknn_model/     |
+
+     Attributes:
+         model (torch.nn.Module): The loaded YOLO model.
+         device (torch.device): The device (CPU or GPU) on which the model is loaded.
+         task (str): The type of task the model performs (detect, segment, classify, pose).
+         names (dict): A dictionary of class names that the model can detect.
+         stride (int): The model stride, typically 32 for YOLO models.
+         fp16 (bool): Whether the model uses half-precision (FP16) inference.
+
+     Methods:
+         forward: Run inference on an input image.
+         from_numpy: Convert numpy array to tensor.
+         warmup: Warm up the model with a dummy input.
+         _model_type: Determine the model type from file path.
+
+     Examples:
+         >>> model = AutoBackend(weights="yolo11n.pt", device="cuda")
+         >>> results = model(img)
+     """
+
+     @torch.no_grad()
+     def __init__(
+         self,
+         weights: Union[str, List[str], torch.nn.Module] = "yolo11n.pt",
+         device: torch.device = torch.device("cpu"),
+         dnn: bool = False,
+         data: Optional[Union[str, Path]] = None,
+         fp16: bool = False,
+         batch: int = 1,
+         fuse: bool = True,
+         verbose: bool = True,
+     ):
+         """
+         Initialize the AutoBackend for inference.
+
+         Args:
+             weights (str | List[str] | torch.nn.Module): Path to the model weights file or a module instance.
+             device (torch.device): Device to run the model on.
+             dnn (bool): Use OpenCV DNN module for ONNX inference.
+             data (str | Path | optional): Path to the additional data.yaml file containing class names.
+             fp16 (bool): Enable half-precision inference. Supported only on specific backends.
+             batch (int): Batch size to assume for inference.
+             fuse (bool): Fuse Conv2D + BatchNorm layers for optimization.
+             verbose (bool): Enable verbose logging.
+         """
+         super().__init__()
+         w = str(weights[0] if isinstance(weights, list) else weights)
+         nn_module = isinstance(weights, torch.nn.Module)
+         (
+             pt,
+             jit,
+             onnx,
+             xml,
+             engine,
+             coreml,
+             saved_model,
+             pb,
+             tflite,
+             edgetpu,
+             tfjs,
+             paddle,
+             mnn,
+             ncnn,
+             imx,
+             rknn,
+             triton,
+         ) = self._model_type(w)
+         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
+         nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch BCHW)
+         stride, ch = 32, 3  # default stride and channels
+         end2end, dynamic = False, False
+         model, metadata, task = None, None, None
+
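The `_model_type` call above returns one boolean per export format (in `export_formats()` order, matching the unpacking) plus a trailing Triton flag. A small sketch of what that looks like for an ONNX path:

```python
from ultralytics.nn.autobackend import AutoBackend

flags = AutoBackend._model_type("yolo11n.onnx")
pt, jit, onnx = flags[:3]  # same order as the unpacking above
print(onnx, any(flags[3:]))  # True False -> only the ONNX flag is set
```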
+         # Set device
+         cuda = isinstance(device, torch.device) and torch.cuda.is_available() and device.type != "cpu"  # use CUDA
+         if cuda and not any([nn_module, pt, jit, engine, onnx, paddle]):  # GPU dataloader formats
+             device = torch.device("cpu")
+             cuda = False
+
+         # Download if not local
+         if not (pt or triton or nn_module):
+             w = attempt_download_asset(w)
+
+         # In-memory PyTorch model
+         if nn_module:
+             if fuse:
+                 weights = weights.fuse(verbose=verbose)  # fuse before move to gpu
+             model = weights.to(device)
+             if hasattr(model, "kpt_shape"):
+                 kpt_shape = model.kpt_shape  # pose-only
+             stride = max(int(model.stride.max()), 32)  # model stride
+             names = model.module.names if hasattr(model, "module") else model.names  # get class names
+             model.half() if fp16 else model.float()
+             ch = model.yaml.get("channels", 3)
+             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
+             pt = True
+
+         # PyTorch
+         elif pt:
+             from ultralytics.nn.tasks import attempt_load_weights
+
+             model = attempt_load_weights(
+                 weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse
+             )
+             if hasattr(model, "kpt_shape"):
+                 kpt_shape = model.kpt_shape  # pose-only
+             stride = max(int(model.stride.max()), 32)  # model stride
+             names = model.module.names if hasattr(model, "module") else model.names  # get class names
+             model.half() if fp16 else model.float()
+             ch = model.yaml.get("channels", 3)
+             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
+
+         # TorchScript
+         elif jit:
+             import torchvision  # noqa - https://github.com/ultralytics/ultralytics/pull/19747
+
+             LOGGER.info(f"Loading {w} for TorchScript inference...")
+             extra_files = {"config.txt": ""}  # model metadata
+             model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
+             model.half() if fp16 else model.float()
+             if extra_files["config.txt"]:  # load metadata dict
+                 metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
+
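A sketch of the `config.txt` metadata round-trip this branch relies on (the exporter is assumed to store a JSON dict in the extra file; the model here is a stand-in):

```python
import json

import torch

scripted = torch.jit.script(torch.nn.Identity())  # stand-in for a real model
torch.jit.save(scripted, "m.torchscript", _extra_files={"config.txt": json.dumps({"stride": 32})})

extra_files = {"config.txt": ""}  # populated in place by torch.jit.load
torch.jit.load("m.torchscript", _extra_files=extra_files)
print(json.loads(extra_files["config.txt"]))  # {'stride': 32}
```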
+         # ONNX OpenCV DNN
+         elif dnn:
+             LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
+             check_requirements("opencv-python>=4.5.4")
+             net = cv2.dnn.readNetFromONNX(w)
+
+         # ONNX Runtime and IMX
+         elif onnx or imx:
+             LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
+             check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
+             import onnxruntime
+
+             providers = ["CPUExecutionProvider"]
+             if cuda:
+                 if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
+                     providers.insert(0, "CUDAExecutionProvider")
+                 else:  # Only log warning if CUDA was requested but unavailable
+                     LOGGER.warning("Failed to start ONNX Runtime with CUDA. Using CPU...")
+                     device = torch.device("cpu")
+                     cuda = False
+             LOGGER.info(f"Using ONNX Runtime {providers[0]}")
+             if onnx:
+                 session = onnxruntime.InferenceSession(w, providers=providers)
+             else:
+                 check_requirements(
+                     ["model-compression-toolkit>=2.3.0", "sony-custom-layers[torch]>=0.3.0", "onnxruntime-extensions"]
+                 )
+                 w = next(Path(w).glob("*.onnx"))
+                 LOGGER.info(f"Loading {w} for ONNX IMX inference...")
+                 import mct_quantizers as mctq
+                 from sony_custom_layers.pytorch.nms import nms_ort  # noqa
+
+                 session_options = mctq.get_ort_session_options()
+                 session_options.enable_mem_reuse = False  # fix the shape mismatch from onnxruntime
+                 session = onnxruntime.InferenceSession(w, session_options, providers=["CPUExecutionProvider"])
+                 task = "detect"
+
+             output_names = [x.name for x in session.get_outputs()]
+             metadata = session.get_modelmeta().custom_metadata_map
+             dynamic = isinstance(session.get_outputs()[0].shape[0], str)
+             fp16 = "float16" in session.get_inputs()[0].type
+             if not dynamic:
+                 io = session.io_binding()
+                 bindings = []
+                 for output in session.get_outputs():
+                     out_fp16 = "float16" in output.type
+                     y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)
+                     io.bind_output(
+                         name=output.name,
+                         device_type=device.type,
+                         device_id=device.index if cuda else 0,
+                         element_type=np.float16 if out_fp16 else np.float32,
+                         shape=tuple(y_tensor.shape),
+                         buffer_ptr=y_tensor.data_ptr(),
+                     )
+                     bindings.append(y_tensor)
+
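For comparison with the io_binding path above, the plain ONNX Runtime call used in the dynamic case boils down to this (model path illustrative):

```python
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
inp = session.get_inputs()[0]
x = np.zeros([d if isinstance(d, int) else 1 for d in inp.shape], dtype=np.float32)  # dynamic dims -> 1
y = session.run(None, {inp.name: x})  # None -> return all outputs
```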
+         # OpenVINO
+         elif xml:
+             LOGGER.info(f"Loading {w} for OpenVINO inference...")
+             check_requirements("openvino>=2024.0.0")
+             import openvino as ov
+
+             core = ov.Core()
+             device_name = "AUTO"
+             if isinstance(device, str) and device.startswith("intel"):
+                 device_name = device.split(":")[1].upper()  # Intel OpenVINO device
+                 device = torch.device("cpu")
+                 if device_name not in core.available_devices:
+                     LOGGER.warning(f"OpenVINO device '{device_name}' not available. Using 'AUTO' instead.")
+                     device_name = "AUTO"
+             w = Path(w)
+             if not w.is_file():  # if not *.xml
+                 w = next(w.glob("*.xml"))  # get *.xml file from *_openvino_model dir
+             ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
+             if ov_model.get_parameters()[0].get_layout().empty:
+                 ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW"))
+
+             # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
+             inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY"
+             LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch={batch} inference...")
+             ov_compiled_model = core.compile_model(
+                 ov_model,
+                 device_name=device_name,
+                 config={"PERFORMANCE_HINT": inference_mode},
+             )
+             input_name = ov_compiled_model.input().get_any_name()
+             metadata = w.parent / "metadata.yaml"
+
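The two performance hints selected above map onto `compile_model` like this (a minimal sketch using the openvino>=2024 API; paths illustrative):

```python
import openvino as ov

core = ov.Core()
ov_model = core.read_model("model.xml")
latency = core.compile_model(ov_model, "AUTO", {"PERFORMANCE_HINT": "LATENCY"})  # batch=1
throughput = core.compile_model(ov_model, "AUTO", {"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"})  # batch>1
```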
+         # TensorRT
+         elif engine:
+             LOGGER.info(f"Loading {w} for TensorRT inference...")
+
+             if IS_JETSON and check_version(PYTHON_VERSION, "<=3.8.10"):
+                 # fix error: `np.bool` was a deprecated alias for the builtin `bool` for JetPack 4 and JetPack 5 with Python <= 3.8.10
+                 check_requirements("numpy==1.23.5")
+
+             try:  # https://developer.nvidia.com/nvidia-tensorrt-download
+                 import tensorrt as trt  # noqa
+             except ImportError:
+                 if LINUX:
+                     check_requirements("tensorrt>7.0.0,!=10.1.0")
+                 import tensorrt as trt  # noqa
+             check_version(trt.__version__, ">=7.0.0", hard=True)
+             check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
+             if device.type == "cpu":
+                 device = torch.device("cuda:0")
+             Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
+             logger = trt.Logger(trt.Logger.INFO)
+             # Read file
+             with open(w, "rb") as f, trt.Runtime(logger) as runtime:
+                 try:
+                     meta_len = int.from_bytes(f.read(4), byteorder="little")  # read metadata length
+                     metadata = json.loads(f.read(meta_len).decode("utf-8"))  # read metadata
+                     dla = metadata.get("dla", None)
+                     if dla is not None:
+                         runtime.DLA_core = int(dla)
+                 except UnicodeDecodeError:
+                     f.seek(0)  # engine file may lack embedded Ultralytics metadata
+                 model = runtime.deserialize_cuda_engine(f.read())  # read engine
+
+             # Model context
+             try:
+                 context = model.create_execution_context()
+             except Exception as e:  # model is None
+                 LOGGER.error(f"TensorRT model exported with a different version than {trt.__version__}\n")
+                 raise e
+
+             bindings = OrderedDict()
+             output_names = []
+             fp16 = False  # default updated below
+             dynamic = False
+             is_trt10 = not hasattr(model, "num_bindings")
+             num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
+             for i in num:
+                 if is_trt10:
+                     name = model.get_tensor_name(i)
+                     dtype = trt.nptype(model.get_tensor_dtype(name))
+                     is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
+                     if is_input:
+                         if -1 in tuple(model.get_tensor_shape(name)):
+                             dynamic = True
+                             context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
+                         if dtype == np.float16:
+                             fp16 = True
+                     else:
+                         output_names.append(name)
+                     shape = tuple(context.get_tensor_shape(name))
+                 else:  # TensorRT < 10.0
+                     name = model.get_binding_name(i)
+                     dtype = trt.nptype(model.get_binding_dtype(i))
+                     is_input = model.binding_is_input(i)
+                     if model.binding_is_input(i):
+                         if -1 in tuple(model.get_binding_shape(i)):  # dynamic
+                             dynamic = True
+                             context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))
+                         if dtype == np.float16:
+                             fp16 = True
+                     else:
+                         output_names.append(name)
+                     shape = tuple(context.get_binding_shape(i))
+                 im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
+                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
+             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
+             batch_size = bindings["images"].shape[0]  # if dynamic, this is instead max batch size
+
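The `try`/`except UnicodeDecodeError` above reads a length-prefixed metadata header: a 4-byte little-endian length, then a UTF-8 JSON dict, then the engine bytes. A sketch of that framing from the writer's side (the prefix is an Ultralytics export convention; plain engines without it fall through via `f.seek(0)`):

```python
import json

meta = json.dumps({"stride": 32, "task": "detect", "dla": None}).encode("utf-8")
with open("model.engine", "wb") as f:
    f.write(len(meta).to_bytes(4, byteorder="little"))  # metadata length
    f.write(meta)  # metadata JSON
    # f.write(serialized_engine)  # serialized TensorRT engine bytes would follow
```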
+         # CoreML
+         elif coreml:
+             LOGGER.info(f"Loading {w} for CoreML inference...")
+             import coremltools as ct
+
+             model = ct.models.MLModel(w)
+             metadata = dict(model.user_defined_metadata)
+
+         # TF SavedModel
+         elif saved_model:
+             LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
+             import tensorflow as tf
+
+             keras = False  # assume TF1 saved_model
+             model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
+             metadata = Path(w) / "metadata.yaml"
+
+         # TF GraphDef
+         elif pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+             LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
+             import tensorflow as tf
+
+             from ultralytics.engine.exporter import gd_outputs
+
+             def wrap_frozen_graph(gd, inputs, outputs):
+                 """Wrap frozen graphs for deployment."""
+                 x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                 ge = x.graph.as_graph_element
+                 return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
+
+             gd = tf.Graph().as_graph_def()  # TF GraphDef
+             with open(w, "rb") as f:
+                 gd.ParseFromString(f.read())
+             frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+             try:  # find metadata in SavedModel alongside GraphDef
+                 metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
+             except StopIteration:
+                 pass
+
+         # TFLite or TFLite Edge TPU
+         elif tflite or edgetpu:  # https://ai.google.dev/edge/litert/microcontrollers/python
+             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
+                 from tflite_runtime.interpreter import Interpreter, load_delegate
+             except ImportError:
+                 import tensorflow as tf
+
+                 Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
+             if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
+                 device = device[3:] if str(device).startswith("tpu") else ":0"
+                 LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...")
+                 delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
+                     platform.system()
+                 ]
+                 interpreter = Interpreter(
+                     model_path=w,
+                     experimental_delegates=[load_delegate(delegate, options={"device": device})],
+                 )
+                 device = "cpu"  # Required, otherwise PyTorch will try to use the wrong device
+             else:  # TFLite
+                 LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
+                 interpreter = Interpreter(model_path=w)  # load TFLite model
+             interpreter.allocate_tensors()  # allocate
+             input_details = interpreter.get_input_details()  # inputs
+             output_details = interpreter.get_output_details()  # outputs
+             # Load metadata
+             try:
+                 with zipfile.ZipFile(w, "r") as zf:
+                     name = zf.namelist()[0]
+                     contents = zf.read(name).decode("utf-8")
+                     if name == "metadata.json":  # Custom Ultralytics metadata dict for Python>=3.12
+                         metadata = json.loads(contents)
+                     else:
+                         metadata = ast.literal_eval(contents)  # Default tflite-support metadata for Python<=3.11
+             except (zipfile.BadZipFile, SyntaxError, ValueError, json.JSONDecodeError):
+                 pass
+
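The metadata read above works because a zip archive can be appended to a `.tflite` flatbuffer without corrupting it, and `zipfile` locates the archive from the end of the file. A sketch of the write/read cycle (file name illustrative):

```python
import json
import zipfile

with zipfile.ZipFile("model.tflite", "a") as zf:  # "a" appends a new archive to a non-zip file
    zf.writestr("metadata.json", json.dumps({"stride": 32, "task": "detect"}))

with zipfile.ZipFile("model.tflite", "r") as zf:
    print(json.loads(zf.read(zf.namelist()[0])))  # {'stride': 32, 'task': 'detect'}
```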
+         # TF.js
+         elif tfjs:
+             raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
+
+         # PaddlePaddle
+         elif paddle:
+             LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
+             check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle>=3.0.0")
+             import paddle.inference as pdi  # noqa
+
+             w = Path(w)
+             model_file, params_file = None, None
+             if w.is_dir():
+                 model_file = next(w.rglob("*.json"), None)
+                 params_file = next(w.rglob("*.pdiparams"), None)
+             elif w.suffix == ".pdiparams":
+                 model_file = w.with_name("model.json")
+                 params_file = w
+
+             if not (model_file and params_file and model_file.is_file() and params_file.is_file()):
+                 raise FileNotFoundError(f"Paddle model not found in {w}. Both .json and .pdiparams files are required.")
+
+             config = pdi.Config(str(model_file), str(params_file))
+             if cuda:
+                 config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
+             predictor = pdi.create_predictor(config)
+             input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+             output_names = predictor.get_output_names()
+             metadata = w / "metadata.yaml"
+
+         # MNN
+         elif mnn:
+             LOGGER.info(f"Loading {w} for MNN inference...")
+             check_requirements("MNN")  # requires MNN
+             import os
+
+             import MNN
+
+             config = {"precision": "low", "backend": "CPU", "numThread": (os.cpu_count() + 1) // 2}
+             rt = MNN.nn.create_runtime_manager((config,))
+             net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True)
+
+             def torch_to_mnn(x):
+                 return MNN.expr.const(x.data_ptr(), x.shape)
+
+             metadata = json.loads(net.get_info()["bizCode"])
+
+         # NCNN
+         elif ncnn:
+             LOGGER.info(f"Loading {w} for NCNN inference...")
+             check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires NCNN
+             import ncnn as pyncnn
+
+             net = pyncnn.Net()
+             net.opt.use_vulkan_compute = cuda
+             w = Path(w)
+             if not w.is_file():  # if not *.param
+                 w = next(w.glob("*.param"))  # get *.param file from *_ncnn_model dir
+             net.load_param(str(w))
+             net.load_model(str(w.with_suffix(".bin")))
+             metadata = w.parent / "metadata.yaml"
+
+         # NVIDIA Triton Inference Server
+         elif triton:
+             check_requirements("tritonclient[all]")
+             from ultralytics.utils.triton import TritonRemoteModel
+
+             model = TritonRemoteModel(w)
+             metadata = model.metadata
+
+         # RKNN
+         elif rknn:
+             if not is_rockchip():
+                 raise OSError("RKNN inference is only supported on Rockchip devices.")
+             LOGGER.info(f"Loading {w} for RKNN inference...")
+             check_requirements("rknn-toolkit-lite2")
+             from rknnlite.api import RKNNLite
+
+             w = Path(w)
+             if not w.is_file():  # if not *.rknn
+                 w = next(w.rglob("*.rknn"))  # get *.rknn file from *_rknn_model dir
+             rknn_model = RKNNLite()
+             rknn_model.load_rknn(str(w))
+             rknn_model.init_runtime()
+             metadata = w.parent / "metadata.yaml"
+
+         # Any other format (unsupported)
+         else:
+             from ultralytics.engine.exporter import export_formats
+
+             raise TypeError(
+                 f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n"
+                 f"See https://docs.ultralytics.com/modes/predict for help."
+             )
+
+         # Load external metadata YAML
+         if isinstance(metadata, (str, Path)) and Path(metadata).exists():
+             metadata = YAML.load(metadata)
+         if metadata and isinstance(metadata, dict):
+             for k, v in metadata.items():
+                 if k in {"stride", "batch", "channels"}:
+                     metadata[k] = int(v)
+                 elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
+                     metadata[k] = eval(v)
+             stride = metadata["stride"]
+             task = metadata["task"]
+             batch = metadata["batch"]
+             imgsz = metadata["imgsz"]
+             names = metadata["names"]
+             kpt_shape = metadata.get("kpt_shape")
+             end2end = metadata.get("args", {}).get("nms", False)
+             dynamic = metadata.get("args", {}).get("dynamic", dynamic)
+             ch = metadata.get("channels", 3)
+         elif not (pt or triton or nn_module):
+             LOGGER.warning(f"Metadata not found for 'model={weights}'")
+
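The coercion loop above exists because exported `metadata.yaml` files may store structured fields as strings; a sketch of the round-trip (values illustrative):

```python
metadata = {"stride": "32", "imgsz": "[640, 640]", "names": "{0: 'person', 1: 'car'}"}
metadata["stride"] = int(metadata["stride"])
for k in ("imgsz", "names"):
    if isinstance(metadata[k], str):
        metadata[k] = eval(metadata[k])  # mirrors the eval() used above
print(metadata)  # {'stride': 32, 'imgsz': [640, 640], 'names': {0: 'person', 1: 'car'}}
```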
+         # Check names
+         if "names" not in locals():  # names missing
+             names = default_class_names(data)
+         names = check_class_names(names)
+
+         # Disable gradients
+         if pt:
+             for p in model.parameters():
+                 p.requires_grad = False
+
+         self.__dict__.update(locals())  # assign all variables to self
+
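Taken together, a typical construction-and-inference cycle looks like this (weights path illustrative; note that `warmup` below is a no-op on CPU for non-Triton backends):

```python
import torch

from ultralytics.nn.autobackend import AutoBackend

backend = AutoBackend(weights="yolo11n.pt", device=torch.device("cpu"))
backend.warmup(imgsz=(1, 3, 640, 640))
im = torch.zeros(1, backend.ch, 640, 640)  # dummy BCHW input
y = backend(im)  # nn.Module __call__ -> forward()
```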
570
+ def forward(self, im, augment=False, visualize=False, embed=None, **kwargs):
571
+ """
572
+ Runs inference on the YOLOv8 MultiBackend model.
573
+
574
+ Args:
575
+ im (torch.Tensor): The image tensor to perform inference on.
576
+ augment (bool): Whether to perform data augmentation during inference.
577
+ visualize (bool): Whether to visualize the output predictions.
578
+ embed (list | None): A list of feature vectors/embeddings to return.
579
+ **kwargs (Any): Additional keyword arguments for model configuration.
580
+
581
+ Returns:
582
+ (torch.Tensor | List[torch.Tensor]): The raw output tensor(s) from the model.
583
+ """
584
+ b, ch, h, w = im.shape # batch, channel, height, width
585
+ if self.fp16 and im.dtype != torch.float16:
586
+ im = im.half() # to FP16
587
+ if self.nhwc:
588
+ im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
589
+
590
+ # PyTorch
591
+ if self.pt or self.nn_module:
592
+ y = self.model(im, augment=augment, visualize=visualize, embed=embed, **kwargs)
593
+
594
+ # TorchScript
595
+ elif self.jit:
596
+ y = self.model(im)
597
+
598
+ # ONNX OpenCV DNN
599
+ elif self.dnn:
600
+ im = im.cpu().numpy() # torch to numpy
601
+ self.net.setInput(im)
602
+ y = self.net.forward()
603
+
604
+ # ONNX Runtime
605
+ elif self.onnx or self.imx:
606
+ if self.dynamic:
607
+ im = im.cpu().numpy() # torch to numpy
608
+ y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
609
+ else:
610
+ if not self.cuda:
611
+ im = im.cpu()
612
+ self.io.bind_input(
613
+ name="images",
614
+ device_type=im.device.type,
615
+ device_id=im.device.index if im.device.type == "cuda" else 0,
616
+ element_type=np.float16 if self.fp16 else np.float32,
617
+ shape=tuple(im.shape),
618
+ buffer_ptr=im.data_ptr(),
619
+ )
620
+ self.session.run_with_iobinding(self.io)
621
+ y = self.bindings
622
+ if self.imx:
623
+ # boxes, conf, cls
624
+ y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)
625
+
626
+ # OpenVINO
627
+ elif self.xml:
628
+ im = im.cpu().numpy() # FP32
629
+
630
+ if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes
631
+ n = im.shape[0] # number of images in batch
632
+ results = [None] * n # preallocate list with None to match the number of images
633
+
634
+ def callback(request, userdata):
635
+ """Places result in preallocated list using userdata index."""
636
+ results[userdata] = request.results
637
+
638
+ # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
639
+ async_queue = self.ov.AsyncInferQueue(self.ov_compiled_model)
640
+ async_queue.set_callback(callback)
641
+ for i in range(n):
642
+ # Start async inference with userdata=i to specify the position in results list
643
+ async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW
644
+ async_queue.wait_all() # wait for all inference requests to complete
645
+ y = np.concatenate([list(r.values())[0] for r in results])
646
+
647
+ else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
648
+ y = list(self.ov_compiled_model(im).values())
649
+
650
+ # TensorRT
651
+ elif self.engine:
652
+ if self.dynamic and im.shape != self.bindings["images"].shape:
653
+ if self.is_trt10:
654
+ self.context.set_input_shape("images", im.shape)
655
+ self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
656
+ for name in self.output_names:
657
+ self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
658
+ else:
659
+ i = self.model.get_binding_index("images")
660
+ self.context.set_binding_shape(i, im.shape)
661
+ self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
662
+ for name in self.output_names:
663
+ i = self.model.get_binding_index(name)
664
+ self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
665
+
666
+ s = self.bindings["images"].shape
667
+ assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
668
+ self.binding_addrs["images"] = int(im.data_ptr())
669
+ self.context.execute_v2(list(self.binding_addrs.values()))
670
+ y = [self.bindings[x].data for x in sorted(self.output_names)]
671
+
672
+ # CoreML
673
+ elif self.coreml:
674
+ im = im[0].cpu().numpy()
675
+ im_pil = Image.fromarray((im * 255).astype("uint8"))
676
+ # im = im.resize((192, 320), Image.BILINEAR)
677
+ y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized
678
+ if "confidence" in y:
679
+ raise TypeError(
680
+ "Ultralytics only supports inference of non-pipelined CoreML models exported with "
681
+ f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export."
682
+ )
683
+ # TODO: CoreML NMS inference handling
684
+ # from ultralytics.utils.ops import xywh2xyxy
685
+ # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
686
+ # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
687
+ # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
688
+ y = list(y.values())
689
+ if len(y) == 2 and len(y[1].shape) != 4: # segmentation model
690
+ y = list(reversed(y)) # reversed for segmentation models (pred, proto)
691
+
692
+ # PaddlePaddle
693
+ elif self.paddle:
694
+ im = im.cpu().numpy().astype(np.float32)
695
+ self.input_handle.copy_from_cpu(im)
696
+ self.predictor.run()
697
+ y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
698
+
699
+ # MNN
700
+ elif self.mnn:
701
+ input_var = self.torch_to_mnn(im)
702
+ output_var = self.net.onForward([input_var])
703
+ y = [x.read() for x in output_var]
704
+
705
+ # NCNN
706
+ elif self.ncnn:
707
+ mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
708
+ with self.net.create_extractor() as ex:
709
+ ex.input(self.net.input_names()[0], mat_in)
710
+ # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130
711
+ y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())]
712
+
713
+ # NVIDIA Triton Inference Server
714
+ elif self.triton:
715
+ im = im.cpu().numpy() # torch to numpy
716
+ y = self.model(im)
717
+
718
+ # RKNN
719
+ elif self.rknn:
720
+ im = (im.cpu().numpy() * 255).astype("uint8")
721
+ im = im if isinstance(im, (list, tuple)) else [im]
722
+ y = self.rknn_model.inference(inputs=im)
723
+
724
+ # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
725
+ else:
726
+ im = im.cpu().numpy()
727
+ if self.saved_model: # SavedModel
728
+ y = self.model(im, training=False) if self.keras else self.model.serving_default(im)
729
+ if not isinstance(y, list):
730
+ y = [y]
731
+ elif self.pb: # GraphDef
732
+ y = self.frozen_func(x=self.tf.constant(im))
733
+ else: # Lite or Edge TPU
734
+ details = self.input_details[0]
735
+ is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model
736
+ if is_int:
737
+ scale, zero_point = details["quantization"]
738
+ im = (im / scale + zero_point).astype(details["dtype"]) # de-scale
739
+ self.interpreter.set_tensor(details["index"], im)
740
+ self.interpreter.invoke()
741
+ y = []
742
+ for output in self.output_details:
743
+ x = self.interpreter.get_tensor(output["index"])
744
+ if is_int:
745
+ scale, zero_point = output["quantization"]
746
+ x = (x.astype(np.float32) - zero_point) * scale # re-scale
747
+ if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
748
+ # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
749
+ # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
750
+ if x.shape[-1] == 6 or self.end2end: # end-to-end model
751
+ x[:, :, [0, 2]] *= w
752
+ x[:, :, [1, 3]] *= h
753
+ if self.task == "pose":
754
+ x[:, :, 6::3] *= w
755
+ x[:, :, 7::3] *= h
756
+ else:
757
+ x[:, [0, 2]] *= w
758
+ x[:, [1, 3]] *= h
759
+ if self.task == "pose":
760
+ x[:, 5::3] *= w
761
+ x[:, 6::3] *= h
762
+ y.append(x)
763
+ # TF segment fixes: export is reversed vs ONNX export and protos are transposed
764
+ if len(y) == 2: # segment with (det, proto) output order reversed
765
+ if len(y[1].shape) != 4:
766
+ y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
767
+ if y[1].shape[-1] == 6: # end-to-end model
768
+ y = [y[1]]
769
+ else:
770
+ y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
771
+ y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
772
+
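The de-scale/re-scale pair in the quantized TFLite path implements the standard affine quantization scheme, q = round(x / scale) + zero_point and x ≈ (q - zero_point) * scale. A standalone sketch (scale and zero point illustrative):

```python
import numpy as np

scale, zero_point = 0.00392, -128  # illustrative int8 quantization params
x = np.array([0.0, 0.25, 0.5], dtype=np.float32)
q = (x / scale + zero_point).astype(np.int8)  # "de-scale" above
x_hat = (q.astype(np.float32) - zero_point) * scale  # "re-scale" above
print(np.abs(x - x_hat).max() <= scale)  # True: error bounded by one quantization step
```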
+         # for x in y:
+         #     print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape)  # debug shapes
+         if isinstance(y, (list, tuple)):
+             if len(self.names) == 999 and (self.task == "segment" or len(y) == 2):  # segments and names not defined
+                 nc = y[0].shape[1] - y[1].shape[1] - 4  # y = (1, 32, 160, 160), (1, 116, 8400)
+                 self.names = {i: f"class{i}" for i in range(nc)}
+             return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
+         else:
+             return self.from_numpy(y)
+
+     def from_numpy(self, x):
+         """
+         Convert a numpy array to a tensor.
+
+         Args:
+             x (np.ndarray): The array to be converted.
+
+         Returns:
+             (torch.Tensor): The converted tensor.
+         """
+         return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
+
+     def warmup(self, imgsz=(1, 3, 640, 640)):
+         """
+         Warm up the model by running one forward pass with a dummy input.
+
+         Args:
+             imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width).
+         """
+         import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)
+
+         warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
+         if any(warmup_types) and (self.device.type != "cpu" or self.triton):
+             im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
+             for _ in range(2 if self.jit else 1):
+                 self.forward(im)  # warmup
+
+     @staticmethod
+     def _model_type(p="path/to/model.pt"):
+         """
+         Take a path to a model file and return the model type.
+
+         Args:
+             p (str): Path to the model file.
+
+         Returns:
+             (List[bool]): List of booleans, one per supported export format, plus a trailing Triton flag.
+
+         Examples:
+             >>> AutoBackend._model_type("path/to/model.onnx")  # returns a list of bools with the ONNX flag set
+         """
+         from ultralytics.engine.exporter import export_formats
+
+         sf = export_formats()["Suffix"]  # export suffixes
+         if not is_url(p) and not isinstance(p, str):
+             check_suffix(p, sf)  # checks
+         name = Path(p).name
+         types = [s in name for s in sf]
+         types[5] |= name.endswith(".mlmodel")  # retain support for older Apple CoreML *.mlmodel formats
+         types[8] &= not types[9]  # tflite &= not edgetpu
+         if any(types):
+             triton = False
+         else:
+             from urllib.parse import urlsplit
+
+             url = urlsplit(p)
+             triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}
+
+         return types + [triton]
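Finally, the Triton fallback at the end of `_model_type` classifies anything that is not a recognized suffix but parses as an http/grpc URL; a quick sketch:

```python
from urllib.parse import urlsplit

for p in ("http://localhost:8000/yolo11n", "grpc://localhost:8001/yolo11n", "yolo11n.onnx"):
    url = urlsplit(p)
    print(p, bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"})
# http://localhost:8000/yolo11n True
# grpc://localhost:8001/yolo11n True
# yolo11n.onnx False
```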