inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/common/onnx.py

@@ -0,0 +1,379 @@

from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from inference_models.errors import MissingDependencyError, ModelRuntimeError

try:
    import onnxruntime
except ImportError as import_error:
    raise MissingDependencyError(
        message="Could not import the ONNX tools required to run models with the ONNX backend - this error means "
        "that some additional dependencies are not installed in the environment. If you run the `inference-models` "
        "library directly in your Python program, make sure the following extras of the package are installed:\n"
        "\t* `onnx-cpu` - when you wish to use the library with CPU support only\n"
        "\t* `onnx-cu12` - for running on GPU with CUDA 12 installed\n"
        "\t* `onnx-cu118` - for running on GPU with CUDA 11.8 installed\n"
        "\t* `onnx-jp6-cu126` - for running on Jetson with JetPack 6\n"
        "If you see this error using Roboflow infrastructure, make sure the service you use supports the model. "
        "You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error


# torch dtype -> numpy dtype, used when binding torch tensors as ONNX Runtime buffers.
TORCH_TYPES_MAPPING = {
    torch.float32: np.float32,
    torch.float16: np.float16,
    torch.int64: np.int64,
    torch.int32: np.int32,
    torch.uint8: np.uint8,
}

# ORT type strings -> torch dtypes, used when allocating output buffers.
ORT_TYPES_TO_TORCH_TYPES_MAPPING = {
    "tensor(float)": torch.float32,
    "tensor(float16)": torch.float16,
    "tensor(double)": torch.float64,
    "tensor(int32)": torch.int32,
    "tensor(int64)": torch.int64,
    "tensor(int16)": torch.int16,
    "tensor(int8)": torch.int8,
    "tensor(uint8)": torch.uint8,
    "tensor(uint16)": torch.uint16,
    "tensor(uint32)": torch.uint32,
    "tensor(uint64)": torch.uint64,
    "tensor(bool)": torch.bool,
}

# Safe input casts: source dtype -> set of dtypes it may be widened to.
MODEL_INPUT_CASTING = {
    torch.float16: {torch.float16, torch.float32, torch.float64},
    torch.float32: {torch.float16, torch.float32, torch.float64},
    torch.int8: {
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float64,
        torch.float32,
        torch.float16,
    },
    torch.int16: {
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
    },
    torch.int32: {
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
    },
    torch.uint8: {
        torch.uint8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
    },
    torch.bool: {torch.uint8, torch.int8, torch.float16, torch.float32, torch.float64},
}


def set_execution_provider_defaults(
    providers: List[Union[str, tuple]],
    model_package_path: str,
    device: torch.device,
    enable_fp16: bool = True,
    default_onnx_trt_options: bool = True,
) -> List[Union[str, tuple[str, dict[str, Any]]]]:
    result = []
    device_id_options = {}
    if device.index is not None:
        device_id_options["device_id"] = device.index
    for provider in providers:
        if provider == "TensorrtExecutionProvider" and default_onnx_trt_options:
            provider = (
                "TensorrtExecutionProvider",
                {
                    "trt_engine_cache_enable": True,
                    "trt_engine_cache_path": model_package_path,
                    "trt_fp16_enable": enable_fp16,
                    **device_id_options,
                },
            )
        if provider == "CUDAExecutionProvider":
            provider = ("CUDAExecutionProvider", device_id_options)
        result.append(provider)
    return result


def run_session_with_batch_size_limit(
    session: onnxruntime.InferenceSession,
    inputs: Dict[str, torch.Tensor],
    output_shape_mapping: Optional[Dict[str, tuple]] = None,
    max_batch_size: Optional[int] = None,
    min_batch_size: Optional[int] = None,
) -> List[torch.Tensor]:
    if max_batch_size is None:
        return run_session_via_iobinding(
            session=session,
            inputs=inputs,
            output_shape_mapping=output_shape_mapping,
        )
    input_batch_sizes = set()
    for input_tensor in inputs.values():
        input_batch_sizes.add(input_tensor.shape[0])
    if len(input_batch_sizes) != 1:
        raise ModelRuntimeError(
            message="When running a forward pass through the ONNX model, inputs with different batch sizes were "
            "detected. This is an error with the model you are running. If the model was trained or exported "
            "on the Roboflow platform - contact us to get help. Otherwise, verify your model package or "
            "the implementation of the model class.",
            help_url="https://todo",
        )
    input_batch_size = input_batch_sizes.pop()
    if min_batch_size is None and input_batch_size <= max_batch_size:
        # no point iterating
        return run_session_via_iobinding(
            session=session,
            inputs=inputs,
            output_shape_mapping=output_shape_mapping,
        )
    all_results = []
    for _ in session.get_outputs():
        all_results.append([])
    for i in range(0, input_batch_size, max_batch_size):
        batch_inputs = {}
        remainder = 0
        for name, value in inputs.items():
            batched_value = value[i : i + max_batch_size]
            if min_batch_size is not None:
                # Pad the trailing chunk with zeros up to min_batch_size; the
                # padded rows are trimmed from the outputs below.
                remainder = min_batch_size - batched_value.shape[0]
                if remainder > 0:
                    batched_value = torch.cat(
                        (
                            batched_value,
                            torch.zeros(
                                (remainder,) + batched_value.shape[1:],
                                dtype=batched_value.dtype,
                                device=batched_value.device,
                            ),
                        ),
                        dim=0,
                    )
            batched_value = batched_value.contiguous()
            batch_inputs[name] = batched_value
        batch_output_shape_mapping = None
        if output_shape_mapping:
            batch_output_shape_mapping = {}
            for name, shape in output_shape_mapping.items():
                batch_output_shape_mapping[name] = (max_batch_size,) + shape[1:]
        batch_results = run_session_via_iobinding(
            session=session,
            inputs=batch_inputs,
            output_shape_mapping=batch_output_shape_mapping,
        )
        if remainder > 0:
            batch_results = [r[:-remainder] for r in batch_results]
        for partial_result, all_result_element in zip(batch_results, all_results):
            all_result_element.append(partial_result)
    return [torch.cat(e, dim=0).contiguous() for e in all_results]


def run_session_via_iobinding(
    session: onnxruntime.InferenceSession,
    inputs: Dict[str, torch.Tensor],
    output_shape_mapping: Optional[Dict[str, tuple]] = None,
) -> List[torch.Tensor]:
    inputs = auto_cast_session_inputs(
        session=session,
        inputs=inputs,
    )
    device = get_input_device(inputs=inputs)
    if device.type != "cuda":
        # CPU path: plain session.run() on numpy views of the input tensors.
        inputs_np = {name: value.cpu().numpy() for name, value in inputs.items()}
        results = session.run(None, inputs_np)
        return [torch.from_numpy(element).to(device=device) for element in results]
    try:
        import pycuda.driver as cuda

        from inference_models.models.common.cuda import use_primary_cuda_context
    except ImportError as import_error:
        raise MissingDependencyError(
            message="TODO", help_url="https://todo"
        ) from import_error
    cuda.init()
    cuda_device = cuda.Device(device.index or 0)
    with use_primary_cuda_context(cuda_device=cuda_device):
        if output_shape_mapping is None:
            output_shape_mapping = {}
        binding = session.io_binding()
        pre_allocated_outputs: List[Optional[torch.Tensor]] = []
        some_outputs_dynamically_allocated = False
        for output in session.get_outputs():
            if is_tensor_shape_dynamic(output.shape):
                if output.name in output_shape_mapping:
                    # Dynamic shape, but the caller told us the concrete size -
                    # pre-allocate a torch buffer and bind it.
                    torch_output_type = ort_tensor_type_to_torch_tensor_type(
                        output.type
                    )
                    pre_allocated_output = torch.empty(
                        output_shape_mapping[output.name],
                        dtype=torch_output_type,
                        device=device,
                    )
                    binding.bind_output(
                        name=output.name,
                        device_type="cuda",
                        device_id=device.index or 0,
                        element_type=torch_tensor_type_to_onnx_type(torch_output_type),
                        shape=tuple(pre_allocated_output.shape),
                        buffer_ptr=pre_allocated_output.data_ptr(),
                    )
                    pre_allocated_outputs.append(pre_allocated_output)
                else:
                    # Unknown output size - let ONNX Runtime allocate the buffer.
                    binding.bind_output(
                        name=output.name,
                        device_type="cuda",
                        device_id=device.index or 0,
                    )
                    some_outputs_dynamically_allocated = True
                    pre_allocated_outputs.append(None)
            else:
                torch_output_type = ort_tensor_type_to_torch_tensor_type(output.type)
                pre_allocated_output = torch.empty(
                    output.shape,
                    dtype=torch_output_type,
                    device=device,
                )
                binding.bind_output(
                    name=output.name,
                    device_type="cuda",
                    device_id=device.index or 0,
                    element_type=torch_tensor_type_to_onnx_type(torch_output_type),
                    shape=tuple(pre_allocated_output.shape),
                    buffer_ptr=pre_allocated_output.data_ptr(),
                )
                pre_allocated_outputs.append(pre_allocated_output)
        for ort_input in session.get_inputs():
            input_tensor = inputs[ort_input.name].contiguous()
            input_type = torch_tensor_type_to_onnx_type(tensor_dtype=input_tensor.dtype)
            binding.bind_input(
                name=ort_input.name,
                device_type=input_tensor.device.type,
                device_id=input_tensor.device.index or 0,
                element_type=input_type,
                shape=input_tensor.shape,
                buffer_ptr=input_tensor.data_ptr(),
            )
        binding.synchronize_inputs()
        session.run_with_iobinding(binding)
        if not some_outputs_dynamically_allocated:
            return pre_allocated_outputs
        bound_outputs = binding.get_outputs()
        result = []
        for pre_allocated_output, bound_output in zip(
            pre_allocated_outputs, bound_outputs
        ):
            if pre_allocated_output is not None:
                result.append(pre_allocated_output)
                continue
            # This is added for the sake of compatibility with older builds of onnxruntime
            # which do not support zero-copy OrtValue -> torch.Tensor conversion via dlpack.
            if not hasattr(bound_output._ortvalue, "to_dlpack"):
                # slower, but needed :(
                out_tensor = torch.from_numpy(bound_output._ortvalue.numpy()).to(device)
            else:
                dlpack_tensor = bound_output._ortvalue.to_dlpack()
                out_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor)
            result.append(out_tensor)
        return result


def auto_cast_session_inputs(
    session: onnxruntime.InferenceSession, inputs: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    for ort_input in session.get_inputs():
        expected_type = ort_tensor_type_to_torch_tensor_type(ort_input.type)
        if ort_input.name not in inputs:
            raise ModelRuntimeError(
                message="While performing a forward pass through the model, a library bug was discovered - "
                f"the required model input named '{ort_input.name}' is missing. Submit an "
                f"issue to help us solve this problem: https://github.com/roboflow/inference/issues",
                help_url="https://todo",
            )
        actual_type = inputs[ort_input.name].dtype
        if actual_type == expected_type:
            continue
        if not can_model_input_be_casted(source=actual_type, target=expected_type):
            raise ModelRuntimeError(
                message="While performing a forward pass through the model, a library bug was discovered - "
                f"the model requires the input type to be {expected_type}, but the actual input type is {actual_type} - "
                f"this is a bug in the model implementation. Submit an issue to help us solve this problem: "
                f"https://github.com/roboflow/inference/issues",
                help_url="https://todo",
            )
        inputs[ort_input.name] = inputs[ort_input.name].to(dtype=expected_type)
    return inputs


def torch_tensor_type_to_onnx_type(tensor_dtype: torch.dtype) -> Union[np.dtype, int]:
    if tensor_dtype not in TORCH_TYPES_MAPPING:
        raise ModelRuntimeError(
            message=f"While performing a forward pass through the model, the library discovered a tensor of type {tensor_dtype} "
            f"which needs to be passed to the onnxruntime session. Conversion of this type is currently not "
            f"supported in inference. For now, you should assume your model is incompatible with the library. "
            f"To change that state - please submit a new issue: https://github.com/roboflow/inference/issues",
            help_url="https://todo",
        )
    return TORCH_TYPES_MAPPING[tensor_dtype]


def ort_tensor_type_to_torch_tensor_type(ort_dtype: str) -> torch.dtype:
    if ort_dtype not in ORT_TYPES_TO_TORCH_TYPES_MAPPING:
        raise ModelRuntimeError(
            message=f"While performing a forward pass through the model, the library discovered an ORT tensor of type {ort_dtype} "
            f"which needs to be cast into a torch.Tensor. Conversion of this type is currently not "
            f"supported in inference. For now, you should assume your model is incompatible with the library. "
            f"To change that state - please submit a new issue: https://github.com/roboflow/inference/issues",
            help_url="https://todo",
        )
    return ORT_TYPES_TO_TORCH_TYPES_MAPPING[ort_dtype]


def is_tensor_shape_dynamic(shape: tuple) -> bool:
    return any(isinstance(dim, str) for dim in shape)


def can_model_input_be_casted(source: torch.dtype, target: torch.dtype) -> bool:
    if source not in MODEL_INPUT_CASTING:
        return False
    return target in MODEL_INPUT_CASTING[source]


def get_input_device(inputs: Dict[str, torch.Tensor]) -> torch.device:
    device = None
    for input_name, input_tensor in inputs.items():
        if device is None:
            device = input_tensor.device
        elif input_tensor.device != device:
            raise ModelRuntimeError(
                message="While performing a forward pass through the model, the library discovered an input tensor which is "
                f"wrongly allocated on a different device than the rest of the inputs - the input named '{input_name}' "
                f"is allocated on {input_tensor.device}, whereas the rest of the inputs are allocated on {device}. "
                f"This is a bug in the model implementation. To help us fix it, please submit a new issue: "
                f"https://github.com/roboflow/inference/issues",
                help_url="https://todo",
            )
    if device is None:
        raise ModelRuntimeError(
            message="No inputs detected for the model. Raise a new issue to help us fix the problem: "
            "https://github.com/roboflow/inference/issues",
            help_url="https://todo",
        )
    return device
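
For orientation, below is a minimal usage sketch of the two public helpers from the `onnx.py` diff above. The package directory, model file path, input name, and tensor shape are hypothetical, invented for illustration; `set_execution_provider_defaults` and `run_session_with_batch_size_limit` and their signatures come from the file itself, while `onnxruntime.InferenceSession` is the standard ONNX Runtime API. The CUDA path additionally assumes the GPU extras of the package (including `pycuda`) are installed.

import onnxruntime
import torch

from inference_models.models.common.onnx import (
    run_session_with_batch_size_limit,
    set_execution_provider_defaults,
)

device = torch.device("cuda:0")

# Expand bare provider names into (name, options) tuples with TensorRT engine
# caching, FP16, and the right device id filled in.
providers = set_execution_provider_defaults(
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"],
    model_package_path="/models/example-package",  # hypothetical package dir
    device=device,
    enable_fp16=True,
)
session = onnxruntime.InferenceSession(
    "/models/example-package/model.onnx",  # hypothetical model file
    providers=providers,
)

# A batch of 7 images is split into chunks of at most 4; with min_batch_size=4
# the trailing chunk of 3 is zero-padded to 4, and the padded rows are trimmed
# from the outputs before the chunks are concatenated back together.
images = torch.rand((7, 3, 640, 640), dtype=torch.float32, device=device)
outputs = run_session_with_batch_size_limit(
    session=session,
    inputs={"images": images},  # assumes the model's input is named "images"
    max_batch_size=4,
    min_batch_size=4,
)

When `max_batch_size` is left as None, the helper falls back to a single `run_session_via_iobinding` call, so the chunking and padding only come into play for engines built with a fixed or bounded batch dimension.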