dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1028 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1580 -0
- ultralytics/engine/model.py +1125 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +977 -0
- ultralytics/engine/tuner.py +449 -0
- ultralytics/engine/validator.py +387 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +964 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +453 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1020 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +506 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1047 -0
- ultralytics/utils/tal.py +404 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +443 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +168 -0
|
@@ -0,0 +1,964 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import ast
|
|
6
|
+
import json
|
|
7
|
+
import platform
|
|
8
|
+
import zipfile
|
|
9
|
+
from collections import OrderedDict, namedtuple
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import cv2
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
from PIL import Image
|
|
18
|
+
|
|
19
|
+
from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML, is_jetson
|
|
20
|
+
from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml, is_rockchip
|
|
21
|
+
from ultralytics.utils.downloads import attempt_download_asset, is_url
|
|
22
|
+
from ultralytics.utils.nms import non_max_suppression
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def check_class_names(names: list | dict) -> dict[int, str]:
    """Check class names and convert them to a dict with contiguous integer keys if needed.

    Args:
        names (list | dict): Class names as a list (position -> name) or a dict whose keys are integers or
            integer-like strings (e.g. '0').

    Returns:
        (dict): Class names in dict format with integer keys and string values. Values that look like ImageNet
            class codes (first name starts with 'n0', e.g. 'n01440764') are mapped to human-readable names via
            the 'map' entry of cfg/datasets/ImageNet.yaml. An empty list or dict yields an empty dict. Inputs
            that are neither a list nor a dict are returned unchanged.

    Raises:
        KeyError: If the class indices are not exactly 0..n-1 for an n-class dataset (negative or out-of-range).
        ValueError: If a key cannot be converted to an integer.
    """
    if isinstance(names, list):
        names = dict(enumerate(names))  # list position becomes the class index
    if isinstance(names, dict):
        # Normalize keys to int ('0' -> 0) and values to str (True -> 'True')
        names = {int(k): str(v) for k, v in names.items()}
        if not names:
            return names  # nothing to validate; also avoids min()/max() on an empty sequence
        n = len(names)
        lo, hi = min(names), max(names)
        # Keys are unique ints, so 0 <= lo and hi < n together guarantee keys == {0, ..., n-1}
        if lo < 0 or hi >= n:
            raise KeyError(
                f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
                f"{lo}-{hi} defined in your dataset YAML."
            )
        if names[0].startswith("n0"):  # imagenet class codes, i.e. 'n01440764'
            names_map = YAML.load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
            names = {k: names_map[v] for k, v in names.items()}
    return names
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def default_class_names(data: str | Path | None = None) -> dict[int, str]:
    """Return class names read from a dataset YAML, or generic numbered names as a fallback.

    Args:
        data (str | Path, optional): Path to a dataset YAML file whose 'names' entry should be returned.

    Returns:
        (dict): The YAML's 'names' entry, returned as stored (not normalized), when `data` is truthy and
            loading succeeds; otherwise a dict mapping 0-998 to placeholder names 'class0'-'class998'.
    """
    if data:
        try:
            loaded = YAML.load(check_yaml(data))["names"]
        except Exception:
            pass  # deliberate best-effort: any failure resolving/loading the YAML or reading 'names' falls back
        else:
            return loaded
    return {index: f"class{index}" for index in range(999)}  # return default if above errors
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class AutoBackend(nn.Module):
|
|
72
|
+
"""Handle dynamic backend selection for running inference using Ultralytics YOLO models.
|
|
73
|
+
|
|
74
|
+
The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
|
|
75
|
+
range of formats, each with specific naming conventions as outlined below:
|
|
76
|
+
|
|
77
|
+
Supported Formats and Naming Conventions:
|
|
78
|
+
| Format | File Suffix |
|
|
79
|
+
| --------------------- | ----------------- |
|
|
80
|
+
| PyTorch | *.pt |
|
|
81
|
+
| TorchScript | *.torchscript |
|
|
82
|
+
| ONNX Runtime | *.onnx |
|
|
83
|
+
| ONNX OpenCV DNN | *.onnx (dnn=True) |
|
|
84
|
+
| OpenVINO | *openvino_model/ |
|
|
85
|
+
| CoreML | *.mlpackage |
|
|
86
|
+
| TensorRT | *.engine |
|
|
87
|
+
| TensorFlow SavedModel | *_saved_model/ |
|
|
88
|
+
| TensorFlow GraphDef | *.pb |
|
|
89
|
+
| TensorFlow Lite | *.tflite |
|
|
90
|
+
| TensorFlow Edge TPU | *_edgetpu.tflite |
|
|
91
|
+
| PaddlePaddle | *_paddle_model/ |
|
|
92
|
+
| MNN | *.mnn |
|
|
93
|
+
| NCNN | *_ncnn_model/ |
|
|
94
|
+
| IMX | *_imx_model/ |
|
|
95
|
+
| RKNN | *_rknn_model/ |
|
|
96
|
+
| Triton Inference | triton://model |
|
|
97
|
+
| ExecuTorch | *.pte |
|
|
98
|
+
| Axelera | *_axelera_model/ |
|
|
99
|
+
|
|
100
|
+
Attributes:
|
|
101
|
+
model (torch.nn.Module): The loaded YOLO model.
|
|
102
|
+
device (torch.device): The device (CPU or GPU) on which the model is loaded.
|
|
103
|
+
task (str): The type of task the model performs (detect, segment, classify, pose).
|
|
104
|
+
names (dict): A dictionary of class names that the model can detect.
|
|
105
|
+
stride (int): The model stride, typically 32 for YOLO models.
|
|
106
|
+
fp16 (bool): Whether the model uses half-precision (FP16) inference.
|
|
107
|
+
nhwc (bool): Whether the model expects NHWC input format instead of NCHW.
|
|
108
|
+
pt (bool): Whether the model is a PyTorch model.
|
|
109
|
+
jit (bool): Whether the model is a TorchScript model.
|
|
110
|
+
onnx (bool): Whether the model is an ONNX model.
|
|
111
|
+
xml (bool): Whether the model is an OpenVINO model.
|
|
112
|
+
engine (bool): Whether the model is a TensorRT engine.
|
|
113
|
+
coreml (bool): Whether the model is a CoreML model.
|
|
114
|
+
saved_model (bool): Whether the model is a TensorFlow SavedModel.
|
|
115
|
+
pb (bool): Whether the model is a TensorFlow GraphDef.
|
|
116
|
+
tflite (bool): Whether the model is a TensorFlow Lite model.
|
|
117
|
+
edgetpu (bool): Whether the model is a TensorFlow Edge TPU model.
|
|
118
|
+
tfjs (bool): Whether the model is a TensorFlow.js model.
|
|
119
|
+
paddle (bool): Whether the model is a PaddlePaddle model.
|
|
120
|
+
mnn (bool): Whether the model is an MNN model.
|
|
121
|
+
ncnn (bool): Whether the model is an NCNN model.
|
|
122
|
+
imx (bool): Whether the model is an IMX model.
|
|
123
|
+
rknn (bool): Whether the model is an RKNN model.
|
|
124
|
+
triton (bool): Whether the model is a Triton Inference Server model.
|
|
125
|
+
pte (bool): Whether the model is a PyTorch ExecuTorch model.
|
|
126
|
+
axelera (bool): Whether the model is an Axelera model.
|
|
127
|
+
|
|
128
|
+
Methods:
|
|
129
|
+
forward: Run inference on an input image.
|
|
130
|
+
from_numpy: Convert NumPy arrays to tensors on the model device.
|
|
131
|
+
warmup: Warm up the model with a dummy input.
|
|
132
|
+
_model_type: Determine the model type from file path.
|
|
133
|
+
|
|
134
|
+
Examples:
|
|
135
|
+
>>> model = AutoBackend(model="yolo11n.pt", device="cuda")
|
|
136
|
+
>>> results = model(img)
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
@torch.no_grad()
|
|
140
|
+
def __init__(
|
|
141
|
+
self,
|
|
142
|
+
model: str | torch.nn.Module = "yolo11n.pt",
|
|
143
|
+
device: torch.device = torch.device("cpu"),
|
|
144
|
+
dnn: bool = False,
|
|
145
|
+
data: str | Path | None = None,
|
|
146
|
+
fp16: bool = False,
|
|
147
|
+
fuse: bool = True,
|
|
148
|
+
verbose: bool = True,
|
|
149
|
+
):
|
|
150
|
+
"""Initialize the AutoBackend for inference.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
model (str | torch.nn.Module): Path to the model weights file or a module instance.
|
|
154
|
+
device (torch.device): Device to run the model on.
|
|
155
|
+
dnn (bool): Use OpenCV DNN module for ONNX inference.
|
|
156
|
+
data (str | Path, optional): Path to the additional data.yaml file containing class names.
|
|
157
|
+
fp16 (bool): Enable half-precision inference. Supported only on specific backends.
|
|
158
|
+
fuse (bool): Fuse Conv2D + BatchNorm layers for optimization.
|
|
159
|
+
verbose (bool): Enable verbose logging.
|
|
160
|
+
"""
|
|
161
|
+
super().__init__()
|
|
162
|
+
nn_module = isinstance(model, torch.nn.Module)
|
|
163
|
+
(
|
|
164
|
+
pt,
|
|
165
|
+
jit,
|
|
166
|
+
onnx,
|
|
167
|
+
xml,
|
|
168
|
+
engine,
|
|
169
|
+
coreml,
|
|
170
|
+
saved_model,
|
|
171
|
+
pb,
|
|
172
|
+
tflite,
|
|
173
|
+
edgetpu,
|
|
174
|
+
tfjs,
|
|
175
|
+
paddle,
|
|
176
|
+
mnn,
|
|
177
|
+
ncnn,
|
|
178
|
+
imx,
|
|
179
|
+
rknn,
|
|
180
|
+
pte,
|
|
181
|
+
axelera,
|
|
182
|
+
triton,
|
|
183
|
+
) = self._model_type("" if nn_module else model)
|
|
184
|
+
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
|
|
185
|
+
nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn # BHWC formats (vs torch BCHW)
|
|
186
|
+
stride, ch = 32, 3 # default stride and channels
|
|
187
|
+
end2end, dynamic = False, False
|
|
188
|
+
metadata, task = None, None
|
|
189
|
+
|
|
190
|
+
# Set device
|
|
191
|
+
cuda = isinstance(device, torch.device) and torch.cuda.is_available() and device.type != "cpu" # use CUDA
|
|
192
|
+
if cuda and not any([nn_module, pt, jit, engine, onnx, paddle]): # GPU dataloader formats
|
|
193
|
+
device = torch.device("cpu")
|
|
194
|
+
cuda = False
|
|
195
|
+
|
|
196
|
+
# Download if not local
|
|
197
|
+
w = attempt_download_asset(model) if pt else model # weights path
|
|
198
|
+
|
|
199
|
+
# PyTorch (in-memory or file)
|
|
200
|
+
if nn_module or pt:
|
|
201
|
+
if nn_module:
|
|
202
|
+
pt = True
|
|
203
|
+
if fuse:
|
|
204
|
+
if IS_JETSON and is_jetson(jetpack=5):
|
|
205
|
+
# Jetson Jetpack5 requires device before fuse https://github.com/ultralytics/ultralytics/pull/21028
|
|
206
|
+
model = model.to(device)
|
|
207
|
+
model = model.fuse(verbose=verbose)
|
|
208
|
+
model = model.to(device)
|
|
209
|
+
else: # pt file
|
|
210
|
+
from ultralytics.nn.tasks import load_checkpoint
|
|
211
|
+
|
|
212
|
+
model, _ = load_checkpoint(model, device=device, fuse=fuse) # load model, ckpt
|
|
213
|
+
|
|
214
|
+
# Common PyTorch model processing
|
|
215
|
+
if hasattr(model, "kpt_shape"):
|
|
216
|
+
kpt_shape = model.kpt_shape # pose-only
|
|
217
|
+
stride = max(int(model.stride.max()), 32) # model stride
|
|
218
|
+
names = model.module.names if hasattr(model, "module") else model.names # get class names
|
|
219
|
+
model.half() if fp16 else model.float()
|
|
220
|
+
ch = model.yaml.get("channels", 3)
|
|
221
|
+
for p in model.parameters():
|
|
222
|
+
p.requires_grad = False
|
|
223
|
+
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
|
224
|
+
|
|
225
|
+
# TorchScript
|
|
226
|
+
elif jit:
|
|
227
|
+
import torchvision # noqa - https://github.com/ultralytics/ultralytics/pull/19747
|
|
228
|
+
|
|
229
|
+
LOGGER.info(f"Loading {w} for TorchScript inference...")
|
|
230
|
+
extra_files = {"config.txt": ""} # model metadata
|
|
231
|
+
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
|
|
232
|
+
model.half() if fp16 else model.float()
|
|
233
|
+
if extra_files["config.txt"]: # load metadata dict
|
|
234
|
+
metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
|
|
235
|
+
|
|
236
|
+
# ONNX OpenCV DNN
|
|
237
|
+
elif dnn:
|
|
238
|
+
LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
|
|
239
|
+
check_requirements("opencv-python>=4.5.4")
|
|
240
|
+
net = cv2.dnn.readNetFromONNX(w)
|
|
241
|
+
|
|
242
|
+
# ONNX Runtime and IMX
|
|
243
|
+
elif onnx or imx:
|
|
244
|
+
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
|
|
245
|
+
check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
|
|
246
|
+
import onnxruntime
|
|
247
|
+
|
|
248
|
+
# Select execution provider: CUDA > CoreML (mps) > CPU
|
|
249
|
+
available = onnxruntime.get_available_providers()
|
|
250
|
+
if cuda and "CUDAExecutionProvider" in available:
|
|
251
|
+
providers = [("CUDAExecutionProvider", {"device_id": device.index}), "CPUExecutionProvider"]
|
|
252
|
+
elif device.type == "mps" and "CoreMLExecutionProvider" in available:
|
|
253
|
+
providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
|
|
254
|
+
else:
|
|
255
|
+
providers = ["CPUExecutionProvider"]
|
|
256
|
+
if cuda:
|
|
257
|
+
LOGGER.warning("CUDA requested but CUDAExecutionProvider not available. Using CPU...")
|
|
258
|
+
device, cuda = torch.device("cpu"), False
|
|
259
|
+
LOGGER.info(
|
|
260
|
+
f"Using ONNX Runtime {onnxruntime.__version__} with {providers[0] if isinstance(providers[0], str) else providers[0][0]}"
|
|
261
|
+
)
|
|
262
|
+
if onnx:
|
|
263
|
+
session = onnxruntime.InferenceSession(w, providers=providers)
|
|
264
|
+
else:
|
|
265
|
+
check_requirements(("model-compression-toolkit>=2.4.1", "edge-mdt-cl<1.1.0", "onnxruntime-extensions"))
|
|
266
|
+
w = next(Path(w).glob("*.onnx"))
|
|
267
|
+
LOGGER.info(f"Loading {w} for ONNX IMX inference...")
|
|
268
|
+
import mct_quantizers as mctq
|
|
269
|
+
from edgemdt_cl.pytorch.nms import nms_ort # noqa - register custom NMS ops
|
|
270
|
+
|
|
271
|
+
session_options = mctq.get_ort_session_options()
|
|
272
|
+
session_options.enable_mem_reuse = False # fix the shape mismatch from onnxruntime
|
|
273
|
+
session = onnxruntime.InferenceSession(w, session_options, providers=["CPUExecutionProvider"])
|
|
274
|
+
|
|
275
|
+
output_names = [x.name for x in session.get_outputs()]
|
|
276
|
+
metadata = session.get_modelmeta().custom_metadata_map
|
|
277
|
+
dynamic = isinstance(session.get_outputs()[0].shape[0], str)
|
|
278
|
+
fp16 = "float16" in session.get_inputs()[0].type
|
|
279
|
+
|
|
280
|
+
# Setup IO binding for optimized inference (CUDA only, not supported for CoreML)
|
|
281
|
+
use_io_binding = not dynamic and cuda
|
|
282
|
+
if use_io_binding:
|
|
283
|
+
io = session.io_binding()
|
|
284
|
+
bindings = []
|
|
285
|
+
for output in session.get_outputs():
|
|
286
|
+
out_fp16 = "float16" in output.type
|
|
287
|
+
y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)
|
|
288
|
+
io.bind_output(
|
|
289
|
+
name=output.name,
|
|
290
|
+
device_type=device.type,
|
|
291
|
+
device_id=device.index if cuda else 0,
|
|
292
|
+
element_type=np.float16 if out_fp16 else np.float32,
|
|
293
|
+
shape=tuple(y_tensor.shape),
|
|
294
|
+
buffer_ptr=y_tensor.data_ptr(),
|
|
295
|
+
)
|
|
296
|
+
bindings.append(y_tensor)
|
|
297
|
+
|
|
298
|
+
# OpenVINO
|
|
299
|
+
elif xml:
|
|
300
|
+
LOGGER.info(f"Loading {w} for OpenVINO inference...")
|
|
301
|
+
check_requirements("openvino>=2024.0.0")
|
|
302
|
+
import openvino as ov
|
|
303
|
+
|
|
304
|
+
core = ov.Core()
|
|
305
|
+
device_name = "AUTO"
|
|
306
|
+
if isinstance(device, str) and device.startswith("intel"):
|
|
307
|
+
device_name = device.split(":")[1].upper() # Intel OpenVINO device
|
|
308
|
+
device = torch.device("cpu")
|
|
309
|
+
if device_name not in core.available_devices:
|
|
310
|
+
LOGGER.warning(f"OpenVINO device '{device_name}' not available. Using 'AUTO' instead.")
|
|
311
|
+
device_name = "AUTO"
|
|
312
|
+
w = Path(w)
|
|
313
|
+
if not w.is_file(): # if not *.xml
|
|
314
|
+
w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir
|
|
315
|
+
ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
|
|
316
|
+
if ov_model.get_parameters()[0].get_layout().empty:
|
|
317
|
+
ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW"))
|
|
318
|
+
|
|
319
|
+
metadata = w.parent / "metadata.yaml"
|
|
320
|
+
if metadata.exists():
|
|
321
|
+
metadata = YAML.load(metadata)
|
|
322
|
+
batch = metadata["batch"]
|
|
323
|
+
dynamic = metadata.get("args", {}).get("dynamic", dynamic)
|
|
324
|
+
# OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
|
|
325
|
+
inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 and dynamic else "LATENCY"
|
|
326
|
+
ov_compiled_model = core.compile_model(
|
|
327
|
+
ov_model,
|
|
328
|
+
device_name=device_name,
|
|
329
|
+
config={"PERFORMANCE_HINT": inference_mode},
|
|
330
|
+
)
|
|
331
|
+
LOGGER.info(
|
|
332
|
+
f"Using OpenVINO {inference_mode} mode for batch={batch} inference on {', '.join(ov_compiled_model.get_property('EXECUTION_DEVICES'))}..."
|
|
333
|
+
)
|
|
334
|
+
input_name = ov_compiled_model.input().get_any_name()
|
|
335
|
+
|
|
336
|
+
# TensorRT
|
|
337
|
+
elif engine:
|
|
338
|
+
LOGGER.info(f"Loading {w} for TensorRT inference...")
|
|
339
|
+
|
|
340
|
+
if IS_JETSON and check_version(PYTHON_VERSION, "<=3.8.10"):
|
|
341
|
+
# fix error: `np.bool` was a deprecated alias for the builtin `bool` for JetPack 4 and JetPack 5 with Python <= 3.8.10
|
|
342
|
+
check_requirements("numpy==1.23.5")
|
|
343
|
+
|
|
344
|
+
try: # https://developer.nvidia.com/nvidia-tensorrt-download
|
|
345
|
+
import tensorrt as trt
|
|
346
|
+
except ImportError:
|
|
347
|
+
if LINUX:
|
|
348
|
+
check_requirements("tensorrt>7.0.0,!=10.1.0")
|
|
349
|
+
import tensorrt as trt
|
|
350
|
+
check_version(trt.__version__, ">=7.0.0", hard=True)
|
|
351
|
+
check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
|
|
352
|
+
if device.type == "cpu":
|
|
353
|
+
device = torch.device("cuda:0")
|
|
354
|
+
Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
|
|
355
|
+
logger = trt.Logger(trt.Logger.INFO)
|
|
356
|
+
# Read file
|
|
357
|
+
with open(w, "rb") as f, trt.Runtime(logger) as runtime:
|
|
358
|
+
try:
|
|
359
|
+
meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length
|
|
360
|
+
metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata
|
|
361
|
+
dla = metadata.get("dla", None)
|
|
362
|
+
if dla is not None:
|
|
363
|
+
runtime.DLA_core = int(dla)
|
|
364
|
+
except UnicodeDecodeError:
|
|
365
|
+
f.seek(0) # engine file may lack embedded Ultralytics metadata
|
|
366
|
+
model = runtime.deserialize_cuda_engine(f.read()) # read engine
|
|
367
|
+
|
|
368
|
+
# Model context
|
|
369
|
+
try:
|
|
370
|
+
context = model.create_execution_context()
|
|
371
|
+
except Exception as e: # model is None
|
|
372
|
+
LOGGER.error(f"TensorRT model exported with a different version than {trt.__version__}\n")
|
|
373
|
+
raise e
|
|
374
|
+
|
|
375
|
+
bindings = OrderedDict()
|
|
376
|
+
output_names = []
|
|
377
|
+
fp16 = False # default updated below
|
|
378
|
+
dynamic = False
|
|
379
|
+
is_trt10 = not hasattr(model, "num_bindings")
|
|
380
|
+
num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
|
|
381
|
+
for i in num:
|
|
382
|
+
# Get tensor info using TRT10+ or legacy API
|
|
383
|
+
if is_trt10:
|
|
384
|
+
name = model.get_tensor_name(i)
|
|
385
|
+
dtype = trt.nptype(model.get_tensor_dtype(name))
|
|
386
|
+
is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
|
|
387
|
+
shape = tuple(model.get_tensor_shape(name))
|
|
388
|
+
profile_shape = tuple(model.get_tensor_profile_shape(name, 0)[2]) if is_input else None
|
|
389
|
+
else:
|
|
390
|
+
name = model.get_binding_name(i)
|
|
391
|
+
dtype = trt.nptype(model.get_binding_dtype(i))
|
|
392
|
+
is_input = model.binding_is_input(i)
|
|
393
|
+
shape = tuple(model.get_binding_shape(i))
|
|
394
|
+
profile_shape = tuple(model.get_profile_shape(0, i)[1]) if is_input else None
|
|
395
|
+
|
|
396
|
+
# Process input/output tensors
|
|
397
|
+
if is_input:
|
|
398
|
+
if -1 in shape:
|
|
399
|
+
dynamic = True
|
|
400
|
+
if is_trt10:
|
|
401
|
+
context.set_input_shape(name, profile_shape)
|
|
402
|
+
else:
|
|
403
|
+
context.set_binding_shape(i, profile_shape)
|
|
404
|
+
if dtype == np.float16:
|
|
405
|
+
fp16 = True
|
|
406
|
+
else:
|
|
407
|
+
output_names.append(name)
|
|
408
|
+
shape = tuple(context.get_tensor_shape(name)) if is_trt10 else tuple(context.get_binding_shape(i))
|
|
409
|
+
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
|
410
|
+
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
|
411
|
+
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
|
412
|
+
|
|
413
|
+
# CoreML
|
|
414
|
+
elif coreml:
|
|
415
|
+
check_requirements(
|
|
416
|
+
["coremltools>=9.0", "numpy>=1.14.5,<=2.3.5"]
|
|
417
|
+
) # latest numpy 2.4.0rc1 breaks coremltools exports
|
|
418
|
+
LOGGER.info(f"Loading {w} for CoreML inference...")
|
|
419
|
+
import coremltools as ct
|
|
420
|
+
|
|
421
|
+
model = ct.models.MLModel(w)
|
|
422
|
+
dynamic = model.get_spec().description.input[0].type.HasField("multiArrayType")
|
|
423
|
+
metadata = dict(model.user_defined_metadata)
|
|
424
|
+
|
|
425
|
+
# TF SavedModel
|
|
426
|
+
elif saved_model:
|
|
427
|
+
LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
|
|
428
|
+
import tensorflow as tf
|
|
429
|
+
|
|
430
|
+
model = tf.saved_model.load(w)
|
|
431
|
+
metadata = Path(w) / "metadata.yaml"
|
|
432
|
+
|
|
433
|
+
# TF GraphDef
|
|
434
|
+
elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
|
435
|
+
LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
|
|
436
|
+
import tensorflow as tf
|
|
437
|
+
|
|
438
|
+
from ultralytics.utils.export.tensorflow import gd_outputs
|
|
439
|
+
|
|
440
|
+
def wrap_frozen_graph(gd, inputs, outputs):
|
|
441
|
+
"""Wrap frozen graphs for deployment."""
|
|
442
|
+
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
|
|
443
|
+
ge = x.graph.as_graph_element
|
|
444
|
+
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
|
445
|
+
|
|
446
|
+
gd = tf.Graph().as_graph_def() # TF GraphDef
|
|
447
|
+
with open(w, "rb") as f:
|
|
448
|
+
gd.ParseFromString(f.read())
|
|
449
|
+
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
|
450
|
+
try: # find metadata in SavedModel alongside GraphDef
|
|
451
|
+
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
|
|
452
|
+
except StopIteration:
|
|
453
|
+
pass
|
|
454
|
+
|
|
455
|
+
# TFLite or TFLite Edge TPU
|
|
456
|
+
elif tflite or edgetpu: # https://ai.google.dev/edge/litert/microcontrollers/python
|
|
457
|
+
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
|
458
|
+
from tflite_runtime.interpreter import Interpreter, load_delegate
|
|
459
|
+
except ImportError:
|
|
460
|
+
import tensorflow as tf
|
|
461
|
+
|
|
462
|
+
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
|
|
463
|
+
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
|
464
|
+
device = device[3:] if str(device).startswith("tpu") else ":0"
|
|
465
|
+
LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...")
|
|
466
|
+
delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
|
|
467
|
+
platform.system()
|
|
468
|
+
]
|
|
469
|
+
interpreter = Interpreter(
|
|
470
|
+
model_path=w,
|
|
471
|
+
experimental_delegates=[load_delegate(delegate, options={"device": device})],
|
|
472
|
+
)
|
|
473
|
+
device = "cpu" # Required, otherwise PyTorch will try to use the wrong device
|
|
474
|
+
else: # TFLite
|
|
475
|
+
LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
|
|
476
|
+
interpreter = Interpreter(model_path=w) # load TFLite model
|
|
477
|
+
interpreter.allocate_tensors() # allocate
|
|
478
|
+
input_details = interpreter.get_input_details() # inputs
|
|
479
|
+
output_details = interpreter.get_output_details() # outputs
|
|
480
|
+
# Load metadata
|
|
481
|
+
try:
|
|
482
|
+
with zipfile.ZipFile(w, "r") as zf:
|
|
483
|
+
name = zf.namelist()[0]
|
|
484
|
+
contents = zf.read(name).decode("utf-8")
|
|
485
|
+
if name == "metadata.json": # Custom Ultralytics metadata dict for Python>=3.12
|
|
486
|
+
metadata = json.loads(contents)
|
|
487
|
+
else:
|
|
488
|
+
metadata = ast.literal_eval(contents) # Default tflite-support metadata for Python<=3.11
|
|
489
|
+
except (zipfile.BadZipFile, SyntaxError, ValueError, json.JSONDecodeError):
|
|
490
|
+
pass
|
|
491
|
+
|
|
492
|
+
# TF.js
|
|
493
|
+
elif tfjs:
|
|
494
|
+
raise NotImplementedError("Ultralytics TF.js inference is not currently supported.")
|
|
495
|
+
|
|
496
|
+
# PaddlePaddle
|
|
497
|
+
elif paddle:
|
|
498
|
+
LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
|
|
499
|
+
check_requirements(
|
|
500
|
+
"paddlepaddle-gpu>=3.0.0,!=3.3.0" # exclude 3.3.0 https://github.com/PaddlePaddle/Paddle/issues/77340
|
|
501
|
+
if torch.cuda.is_available()
|
|
502
|
+
else "paddlepaddle==3.0.0" # pin 3.0.0 for ARM64
|
|
503
|
+
if ARM64
|
|
504
|
+
else "paddlepaddle>=3.0.0,!=3.3.0" # exclude 3.3.0 https://github.com/PaddlePaddle/Paddle/issues/77340
|
|
505
|
+
)
|
|
506
|
+
import paddle.inference as pdi
|
|
507
|
+
|
|
508
|
+
w = Path(w)
|
|
509
|
+
model_file, params_file = None, None
|
|
510
|
+
if w.is_dir():
|
|
511
|
+
model_file = next(w.rglob("*.json"), None)
|
|
512
|
+
params_file = next(w.rglob("*.pdiparams"), None)
|
|
513
|
+
elif w.suffix == ".pdiparams":
|
|
514
|
+
model_file = w.with_name("model.json")
|
|
515
|
+
params_file = w
|
|
516
|
+
|
|
517
|
+
if not (model_file and params_file and model_file.is_file() and params_file.is_file()):
|
|
518
|
+
raise FileNotFoundError(f"Paddle model not found in {w}. Both .json and .pdiparams files are required.")
|
|
519
|
+
|
|
520
|
+
config = pdi.Config(str(model_file), str(params_file))
|
|
521
|
+
if cuda:
|
|
522
|
+
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
|
|
523
|
+
predictor = pdi.create_predictor(config)
|
|
524
|
+
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
|
525
|
+
output_names = predictor.get_output_names()
|
|
526
|
+
metadata = w / "metadata.yaml"
|
|
527
|
+
|
|
528
|
+
# MNN
|
|
529
|
+
elif mnn:
|
|
530
|
+
LOGGER.info(f"Loading {w} for MNN inference...")
|
|
531
|
+
check_requirements("MNN") # requires MNN
|
|
532
|
+
import os
|
|
533
|
+
|
|
534
|
+
import MNN
|
|
535
|
+
|
|
536
|
+
config = {"precision": "low", "backend": "CPU", "numThread": (os.cpu_count() + 1) // 2}
|
|
537
|
+
rt = MNN.nn.create_runtime_manager((config,))
|
|
538
|
+
net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True)
|
|
539
|
+
|
|
540
|
+
def torch_to_mnn(x):
|
|
541
|
+
return MNN.expr.const(x.data_ptr(), x.shape)
|
|
542
|
+
|
|
543
|
+
metadata = json.loads(net.get_info()["bizCode"])
|
|
544
|
+
|
|
545
|
+
# NCNN
|
|
546
|
+
elif ncnn:
|
|
547
|
+
LOGGER.info(f"Loading {w} for NCNN inference...")
|
|
548
|
+
# use git source for ARM64 due to broken PyPI packages https://github.com/Tencent/ncnn/issues/6509
|
|
549
|
+
check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn", cmds="--no-deps")
|
|
550
|
+
import ncnn as pyncnn
|
|
551
|
+
|
|
552
|
+
net = pyncnn.Net()
|
|
553
|
+
if isinstance(cuda, torch.device):
|
|
554
|
+
net.opt.use_vulkan_compute = cuda
|
|
555
|
+
elif isinstance(device, str) and device.startswith("vulkan"):
|
|
556
|
+
net.opt.use_vulkan_compute = True
|
|
557
|
+
net.set_vulkan_device(int(device.split(":")[1]))
|
|
558
|
+
device = torch.device("cpu")
|
|
559
|
+
w = Path(w)
|
|
560
|
+
if not w.is_file(): # if not *.param
|
|
561
|
+
w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir
|
|
562
|
+
net.load_param(str(w))
|
|
563
|
+
net.load_model(str(w.with_suffix(".bin")))
|
|
564
|
+
metadata = w.parent / "metadata.yaml"
|
|
565
|
+
|
|
566
|
+
# NVIDIA Triton Inference Server
|
|
567
|
+
elif triton:
|
|
568
|
+
check_requirements("tritonclient[all]")
|
|
569
|
+
from ultralytics.utils.triton import TritonRemoteModel
|
|
570
|
+
|
|
571
|
+
model = TritonRemoteModel(w)
|
|
572
|
+
metadata = model.metadata
|
|
573
|
+
|
|
574
|
+
# RKNN
|
|
575
|
+
elif rknn:
|
|
576
|
+
if not is_rockchip():
|
|
577
|
+
raise OSError("RKNN inference is only supported on Rockchip devices.")
|
|
578
|
+
LOGGER.info(f"Loading {w} for RKNN inference...")
|
|
579
|
+
check_requirements("rknn-toolkit-lite2")
|
|
580
|
+
from rknnlite.api import RKNNLite
|
|
581
|
+
|
|
582
|
+
w = Path(w)
|
|
583
|
+
if not w.is_file(): # if not *.rknn
|
|
584
|
+
w = next(w.rglob("*.rknn")) # get *.rknn file from *_rknn_model dir
|
|
585
|
+
rknn_model = RKNNLite()
|
|
586
|
+
rknn_model.load_rknn(str(w))
|
|
587
|
+
rknn_model.init_runtime()
|
|
588
|
+
metadata = w.parent / "metadata.yaml"
|
|
589
|
+
|
|
590
|
+
# Axelera
|
|
591
|
+
elif axelera:
|
|
592
|
+
import os
|
|
593
|
+
|
|
594
|
+
if not os.environ.get("AXELERA_RUNTIME_DIR"):
|
|
595
|
+
LOGGER.warning(
|
|
596
|
+
"Axelera runtime environment is not activated."
|
|
597
|
+
"\nPlease run: source /opt/axelera/sdk/latest/axelera_activate.sh"
|
|
598
|
+
"\n\nIf this fails, verify driver installation: https://docs.ultralytics.com/integrations/axelera/#axelera-driver-installation"
|
|
599
|
+
)
|
|
600
|
+
try:
|
|
601
|
+
from axelera.runtime import op
|
|
602
|
+
except ImportError:
|
|
603
|
+
check_requirements(
|
|
604
|
+
"axelera_runtime2==0.1.2",
|
|
605
|
+
cmds="--extra-index-url https://software.axelera.ai/artifactory/axelera-runtime-pypi",
|
|
606
|
+
)
|
|
607
|
+
from axelera.runtime import op
|
|
608
|
+
|
|
609
|
+
w = Path(w)
|
|
610
|
+
if (found := next(w.rglob("*.axm"), None)) is None:
|
|
611
|
+
raise FileNotFoundError(f"No .axm file found in: {w}")
|
|
612
|
+
|
|
613
|
+
ax_model = op.load(str(found))
|
|
614
|
+
metadata = found.parent / "metadata.yaml"
|
|
615
|
+
|
|
616
|
+
# ExecuTorch
|
|
617
|
+
elif pte:
|
|
618
|
+
LOGGER.info(f"Loading {w} for ExecuTorch inference...")
|
|
619
|
+
# TorchAO release compatibility table bug https://github.com/pytorch/ao/issues/2919
|
|
620
|
+
check_requirements("setuptools<71.0.0") # Setuptools bug: https://github.com/pypa/setuptools/issues/4483
|
|
621
|
+
check_requirements(("executorch==1.0.1", "flatbuffers"))
|
|
622
|
+
from executorch.runtime import Runtime
|
|
623
|
+
|
|
624
|
+
w = Path(w)
|
|
625
|
+
if w.is_dir():
|
|
626
|
+
model_file = next(w.rglob("*.pte"))
|
|
627
|
+
metadata = w / "metadata.yaml"
|
|
628
|
+
else:
|
|
629
|
+
model_file = w
|
|
630
|
+
metadata = w.parent / "metadata.yaml"
|
|
631
|
+
|
|
632
|
+
program = Runtime.get().load_program(str(model_file))
|
|
633
|
+
model = program.load_method("forward")
|
|
634
|
+
|
|
635
|
+
# Any other format (unsupported)
|
|
636
|
+
else:
|
|
637
|
+
from ultralytics.engine.exporter import export_formats
|
|
638
|
+
|
|
639
|
+
raise TypeError(
|
|
640
|
+
f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n"
|
|
641
|
+
f"See https://docs.ultralytics.com/modes/predict for help."
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
# Load external metadata YAML
|
|
645
|
+
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
|
|
646
|
+
metadata = YAML.load(metadata)
|
|
647
|
+
if metadata and isinstance(metadata, dict):
|
|
648
|
+
for k, v in metadata.items():
|
|
649
|
+
if k in {"stride", "batch", "channels"}:
|
|
650
|
+
metadata[k] = int(v)
|
|
651
|
+
elif k in {"imgsz", "names", "kpt_shape", "kpt_names", "args"} and isinstance(v, str):
|
|
652
|
+
metadata[k] = ast.literal_eval(v)
|
|
653
|
+
stride = metadata["stride"]
|
|
654
|
+
task = metadata["task"]
|
|
655
|
+
batch = metadata["batch"]
|
|
656
|
+
imgsz = metadata["imgsz"]
|
|
657
|
+
names = metadata["names"]
|
|
658
|
+
kpt_shape = metadata.get("kpt_shape")
|
|
659
|
+
kpt_names = metadata.get("kpt_names")
|
|
660
|
+
end2end = metadata.get("args", {}).get("nms", False)
|
|
661
|
+
dynamic = metadata.get("args", {}).get("dynamic", dynamic)
|
|
662
|
+
ch = metadata.get("channels", 3)
|
|
663
|
+
elif not (pt or triton or nn_module):
|
|
664
|
+
LOGGER.warning(f"Metadata not found for 'model={w}'")
|
|
665
|
+
|
|
666
|
+
# Check names
|
|
667
|
+
if "names" not in locals(): # names missing
|
|
668
|
+
names = default_class_names(data)
|
|
669
|
+
names = check_class_names(names)
|
|
670
|
+
|
|
671
|
+
self.__dict__.update(locals()) # assign all variables to self
|
|
672
|
+
|
|
673
|
+
def forward(
    self,
    im: torch.Tensor,
    augment: bool = False,
    visualize: bool = False,
    embed: list | None = None,
    **kwargs: Any,
) -> torch.Tensor | list[torch.Tensor]:
    """Run inference on an AutoBackend model.

    The input is converted to whatever the active backend expects: FP16 when `self.fp16`, BHWC when `self.nhwc`,
    NumPy for most non-PyTorch runtimes. Raw backend outputs are then normalized to torch tensors via
    `self.from_numpy`.

    Args:
        im (torch.Tensor): The image tensor to perform inference on, in (batch, channel, height, width) layout.
        augment (bool): Whether to perform data augmentation during inference (PyTorch models only).
        visualize (bool): Whether to visualize the output predictions (PyTorch models only).
        embed (list, optional): A list of feature vectors/embeddings to return (PyTorch models only).
        **kwargs (Any): Additional keyword arguments for model configuration (PyTorch models only).

    Returns:
        (torch.Tensor | list[torch.Tensor]): The raw output tensor(s) from the model. A single tensor is returned when
            the backend produced exactly one output, otherwise a list of tensors.

    Raises:
        AssertionError: For TensorRT engines, if the input shape does not match the (possibly resized) engine input.
    """
    _b, _ch, h, w = im.shape  # batch, channel, height, width (read before any NHWC permute)
    if self.fp16 and im.dtype != torch.float16:
        im = im.half()  # to FP16
    if self.nhwc:
        im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

    # PyTorch
    if self.pt or self.nn_module:
        y = self.model(im, augment=augment, visualize=visualize, embed=embed, **kwargs)

    # TorchScript
    elif self.jit:
        y = self.model(im)

    # ONNX OpenCV DNN
    elif self.dnn:
        im = im.cpu().numpy()  # torch to numpy
        self.net.setInput(im)
        y = self.net.forward()

    # ONNX Runtime
    elif self.onnx or self.imx:
        if self.use_io_binding:
            if not self.cuda:
                im = im.cpu()
            # Bind the torch buffer directly so ONNX Runtime reads it without an extra host copy
            self.io.bind_input(
                name="images",
                device_type=im.device.type,
                device_id=im.device.index if im.device.type == "cuda" else 0,
                element_type=np.float16 if self.fp16 else np.float32,
                shape=tuple(im.shape),
                buffer_ptr=im.data_ptr(),
            )
            self.session.run_with_iobinding(self.io)
            y = self.bindings
        else:
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        if self.imx:
            # IMX exports emit separate tensors; stitch them into a single prediction array per task
            if self.task == "detect":
                # boxes, conf, cls
                y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)
            elif self.task == "pose":
                # boxes, conf, kpts
                y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1, dtype=y[0].dtype)
            elif self.task == "segment":
                y = (
                    np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1, dtype=y[0].dtype),
                    y[4],
                )

    # OpenVINO
    elif self.xml:
        im = im.cpu().numpy()  # FP32

        if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}:  # optimized for larger batch-sizes
            n = im.shape[0]  # number of images in batch
            results = [None] * n  # preallocate list with None to match the number of images

            def callback(request, userdata):
                """Place result in preallocated list using userdata index."""
                results[userdata] = request.results

            # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
            async_queue = self.ov.AsyncInferQueue(self.ov_compiled_model)
            async_queue.set_callback(callback)
            for i in range(n):
                # Start async inference with userdata=i to specify the position in results list
                async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
            async_queue.wait_all()  # wait for all inference requests to complete
            # Regroup per-image results into one batched array per model output
            y = [list(r.values()) for r in results]
            y = [np.concatenate(x) for x in zip(*y)]
        else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
            y = list(self.ov_compiled_model(im).values())

    # TensorRT
    elif self.engine:
        if self.dynamic and im.shape != self.bindings["images"].shape:
            # Dynamic engine: resize the input binding and every output buffer to match this input shape
            if self.is_trt10:
                self.context.set_input_shape("images", im.shape)
                self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
                for name in self.output_names:
                    self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
            else:
                i = self.model.get_binding_index("images")
                self.context.set_binding_shape(i, im.shape)
                self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))

        s = self.bindings["images"].shape
        assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
        self.binding_addrs["images"] = int(im.data_ptr())  # point the input binding at this tensor's memory
        self.context.execute_v2(list(self.binding_addrs.values()))
        y = [self.bindings[x].data for x in sorted(self.output_names)]

    # CoreML
    elif self.coreml:
        im = im.cpu().numpy()
        if self.dynamic:
            im = im.transpose(0, 3, 1, 2)  # BHWC back to BCHW for multiArray inputs
        else:
            im = Image.fromarray((im[0] * 255).astype("uint8"))  # image inputs take a single uint8 PIL image
        y = self.model.predict({"image": im})  # coordinates are xywh normalized
        if "confidence" in y:  # NMS included
            from ultralytics.utils.ops import xywh2xyxy

            box = xywh2xyxy(y["coordinates"] * [[w, h, w, h]])  # xyxy pixels
            cls = y["confidence"].argmax(1, keepdims=True)
            y = np.concatenate((box, np.take_along_axis(y["confidence"], cls, axis=1), cls), 1)[None]
        else:
            y = list(y.values())
        if len(y) == 2 and len(y[1].shape) != 4:  # segmentation model
            y = list(reversed(y))  # reversed for segmentation models (pred, proto)

    # PaddlePaddle
    elif self.paddle:
        im = im.cpu().numpy().astype(np.float32)
        self.input_handle.copy_from_cpu(im)
        self.predictor.run()
        y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]

    # MNN
    elif self.mnn:
        input_var = self.torch_to_mnn(im)
        output_var = self.net.onForward([input_var])
        y = [x.read() for x in output_var]

    # NCNN
    elif self.ncnn:
        mat_in = self.pyncnn.Mat(im[0].cpu().numpy())  # only the first image of the batch is run
        with self.net.create_extractor() as ex:
            ex.input(self.net.input_names()[0], mat_in)
            # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130
            y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())]

    # NVIDIA Triton Inference Server
    elif self.triton:
        im = im.cpu().numpy()  # torch to numpy
        y = self.model(im)

    # RKNN
    elif self.rknn:
        im = (im.cpu().numpy() * 255).astype("uint8")
        im = im if isinstance(im, (list, tuple)) else [im]
        y = self.rknn_model.inference(inputs=im)

    # Axelera
    elif self.axelera:
        y = self.ax_model(im.cpu())

    # ExecuTorch
    elif self.pte:
        y = self.model.execute([im])

    # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
    else:
        im = im.cpu().numpy()
        if self.saved_model:  # SavedModel
            y = self.model.serving_default(im)
            if not isinstance(y, list):
                y = [y]
        elif self.pb:  # GraphDef
            y = self.frozen_func(x=self.tf.constant(im))
        else:  # Lite or Edge TPU
            int_types = {np.int8, np.int16}  # quantized TFLite int8 / int16 tensor dtypes
            details = self.input_details[0]
            if details["dtype"] in int_types:  # quantized input: map float image onto the integer grid
                scale, zero_point = details["quantization"]
                im = (im / scale + zero_point).astype(details["dtype"])  # de-scale
            self.interpreter.set_tensor(details["index"], im)
            self.interpreter.invoke()
            y = []
            for output in self.output_details:
                x = self.interpreter.get_tensor(output["index"])
                # Dequantize based on each OUTPUT's own dtype: an integer input does not imply integer outputs,
                # and a float output carries no usable scale (applying it would zero the tensor).
                if output["dtype"] in int_types:
                    scale, zero_point = output["quantization"]
                    x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
                    # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                    # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
                    if x.shape[-1] == 6 or self.end2end:  # end-to-end model
                        x[:, :, [0, 2]] *= w
                        x[:, :, [1, 3]] *= h
                        if self.task == "pose":
                            x[:, :, 6::3] *= w
                            x[:, :, 7::3] *= h
                    else:
                        x[:, [0, 2]] *= w
                        x[:, [1, 3]] *= h
                        if self.task == "pose":
                            x[:, 5::3] *= w
                            x[:, 6::3] *= h
                y.append(x)
        # TF segment fixes: export is reversed vs ONNX export and protos are transposed
        if len(y) == 2:  # segment with (det, proto) output order reversed
            if len(y[1].shape) != 4:
                y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
            if y[1].shape[-1] == 6:  # end-to-end model
                y = [y[1]]
            else:
                y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
        y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

    if isinstance(y, (list, tuple)):
        if len(self.names) == 999 and (self.task == "segment" or len(y) == 2):  # segments and names not defined
            nc = y[0].shape[1] - y[1].shape[1] - 4  # y = (1, 32, 160, 160), (1, 116, 8400)
            self.names = {i: f"class{i}" for i in range(nc)}
        return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
    else:
        return self.from_numpy(y)
|
|
906
|
+
|
|
907
|
+
def from_numpy(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
    """Turn a NumPy backend output into a torch tensor placed on the model device.

    Args:
        x (np.ndarray | torch.Tensor): Backend output; anything that is not an ndarray is passed through untouched.

    Returns:
        (torch.Tensor): A new tensor on `self.device` when `x` is an ndarray, otherwise `x` itself.
    """
    if not isinstance(x, np.ndarray):
        return x  # already a tensor (or other object): hand it back as-is
    converted = torch.tensor(x)  # torch.tensor copies the data, so the result does not alias the array
    return converted.to(self.device)
|
|
917
|
+
|
|
918
|
+
def warmup(self, imgsz: tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:
    """Prime the backend (and NMS) with a throwaway forward pass on a dummy input.

    Nothing happens unless the backend is one of PyTorch, TorchScript, ONNX, TensorRT, SavedModel, GraphDef, Triton
    or an nn.Module, AND the model either lives on a non-CPU device or is served by Triton.

    Args:
        imgsz (tuple[int, int, int, int]): Dummy input shape as (batch, channels, height, width).
    """
    backend_flags = (
        self.pt,
        self.jit,
        self.onnx,
        self.engine,
        self.saved_model,
        self.pb,
        self.triton,
        self.nn_module,
    )
    if not any(backend_flags):
        return
    if self.device.type == "cpu" and not self.triton:
        return  # CPU-only local backends are not warmed up

    dummy_dtype = torch.half if self.fp16 else torch.float
    dummy = torch.empty(*imgsz, dtype=dummy_dtype, device=self.device)  # contents irrelevant, only shape/dtype
    n_passes = 2 if self.jit else 1  # TorchScript gets a second pass
    for _ in range(n_passes):
        self.forward(dummy)  # warmup model

    fake_boxes = torch.rand(1, 84, 16, device=self.device)  # 16 boxes works best empirically
    fake_boxes[:, :4] *= imgsz[-1]  # scale the box rows to the image width
    non_max_suppression(fake_boxes)  # warmup NMS
|
|
932
|
+
|
|
933
|
+
@staticmethod
def _model_type(p: str = "path/to/model.pt") -> list[bool]:
    """Classify a model path by export format, plus a trailing Triton-URL flag.

    Args:
        p (str): Path to the model file (or a Triton server URL).

    Returns:
        (list[bool]): One flag per export-format suffix from `export_formats()`, followed by a Triton flag.

    Examples:
        >>> types = AutoBackend._model_type("path/to/model.onnx")
        >>> assert types[2]  # onnx
    """
    from ultralytics.engine.exporter import export_formats

    suffixes = export_formats()["Suffix"]  # export suffixes
    if not (is_url(p) or isinstance(p, str)):
        check_suffix(p, suffixes)  # only non-string, non-URL paths are validated here
    filename = Path(p).name
    flags = [suffix in filename for suffix in suffixes]
    flags[5] |= filename.endswith(".mlmodel")  # retain support for older Apple CoreML *.mlmodel formats
    flags[8] &= not flags[9]  # tflite &= not edgetpu
    if any(flags):
        return [*flags, False]  # a recognized local format is never a Triton endpoint

    from urllib.parse import urlsplit

    parts = urlsplit(p)
    is_triton = bool(parts.netloc) and bool(parts.path) and parts.scheme in {"http", "grpc"}
    return [*flags, is_triton]
|