inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195) hide show
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,379 @@
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from inference_models.errors import MissingDependencyError, ModelRuntimeError
7
+
8
+ try:
9
+ import onnxruntime
10
+ except ImportError as import_error:
11
+ raise MissingDependencyError(
12
+ message=f"Could not import onnx tools required to run models with ONNX backend - this error means that some additional "
13
+ f"dependencies are not installed in the environment. If you run the `inference-models` library directly in your "
14
+ f"Python program, make sure the following extras of the package are installed: \n"
15
+ f"\t* `onnx-cpu` - when you wish to use library with CPU support only\n"
16
+ f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
17
+ f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
18
+ f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
19
+ f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
20
+ f"You can also contact Roboflow to get support.",
21
+ help_url="https://todo",
22
+ ) from import_error
23
+
24
+
25
+ TORCH_TYPES_MAPPING = {
26
+ torch.float32: np.float32,
27
+ torch.float16: np.float16,
28
+ torch.int64: np.int64,
29
+ torch.int32: np.int32,
30
+ torch.uint8: np.uint8,
31
+ }
32
+
33
+ ORT_TYPES_TO_TORCH_TYPES_MAPPING = {
34
+ "tensor(float)": torch.float32,
35
+ "tensor(float16)": torch.float16,
36
+ "tensor(double)": torch.float64,
37
+ "tensor(int32)": torch.int32,
38
+ "tensor(int64)": torch.int64,
39
+ "tensor(int16)": torch.int16,
40
+ "tensor(int8)": torch.int8,
41
+ "tensor(uint8)": torch.uint8,
42
+ "tensor(uint16)": torch.uint16,
43
+ "tensor(uint32)": torch.uint32,
44
+ "tensor(uint64)": torch.uint64,
45
+ "tensor(bool)": torch.bool,
46
+ }
47
+
48
+ MODEL_INPUT_CASTING = {
49
+ torch.float16: {torch.float16, torch.float32, torch.float64},
50
+ torch.float32: {torch.float16, torch.float32, torch.float64},
51
+ torch.int8: {
52
+ torch.int8,
53
+ torch.int16,
54
+ torch.int32,
55
+ torch.int64,
56
+ torch.float64,
57
+ torch.float32,
58
+ torch.float16,
59
+ },
60
+ torch.int16: {
61
+ torch.int16,
62
+ torch.int32,
63
+ torch.int64,
64
+ torch.float16,
65
+ torch.float32,
66
+ torch.float64,
67
+ },
68
+ torch.int32: {
69
+ torch.int32,
70
+ torch.int64,
71
+ torch.float16,
72
+ torch.float32,
73
+ torch.float64,
74
+ },
75
+ torch.uint8: {
76
+ torch.uint8,
77
+ torch.int16,
78
+ torch.int32,
79
+ torch.int64,
80
+ torch.float16,
81
+ torch.float32,
82
+ torch.float64,
83
+ },
84
+ torch.bool: {torch.uint8, torch.int8, torch.float16, torch.float32, torch.float64},
85
+ }
86
+
87
+
88
+ def set_execution_provider_defaults(
89
+ providers: List[Union[str, tuple]],
90
+ model_package_path: str,
91
+ device: torch.device,
92
+ enable_fp16: bool = True,
93
+ default_onnx_trt_options: bool = True,
94
+ ) -> List[Union[str, tuple[str, dict[str, Any]]]]:
95
+ result = []
96
+ device_id_options = {}
97
+ if device.index is not None:
98
+ device_id_options["device_id"] = device.index
99
+ for provider in providers:
100
+ if provider == "TensorrtExecutionProvider" and default_onnx_trt_options:
101
+ provider = (
102
+ "TensorrtExecutionProvider",
103
+ {
104
+ "trt_engine_cache_enable": True,
105
+ "trt_engine_cache_path": model_package_path,
106
+ "trt_fp16_enable": enable_fp16,
107
+ **device_id_options,
108
+ },
109
+ )
110
+ if provider == "CUDAExecutionProvider":
111
+ provider = ("CUDAExecutionProvider", device_id_options)
112
+ result.append(provider)
113
+ return result
114
+
115
+
116
+ def run_session_with_batch_size_limit(
117
+ session: onnxruntime.InferenceSession,
118
+ inputs: Dict[str, torch.Tensor],
119
+ output_shape_mapping: Optional[Dict[str, tuple]] = None,
120
+ max_batch_size: Optional[int] = None,
121
+ min_batch_size: Optional[int] = None,
122
+ ) -> List[torch.Tensor]:
123
+ if max_batch_size is None:
124
+ return run_session_via_iobinding(
125
+ session=session,
126
+ inputs=inputs,
127
+ output_shape_mapping=output_shape_mapping,
128
+ )
129
+ input_batch_sizes = set()
130
+ for input_tensor in inputs.values():
131
+ input_batch_sizes.add(input_tensor.shape[0])
132
+ if len(input_batch_sizes) != 1:
133
+ raise ModelRuntimeError(
134
+ message="When running forward pass through ONNX model detected inputs with different batch sizes. "
135
+ "This is the error with the model you run. If the model was trained or exported "
136
+ "on Roboflow platform - contact us to get help. Otherwise, verify your model package or "
137
+ "implementation of the model class.",
138
+ help_url="https://todo",
139
+ )
140
+ input_batch_size = input_batch_sizes.pop()
141
+ if min_batch_size is None and input_batch_size <= max_batch_size:
142
+ # no point iterating
143
+ return run_session_via_iobinding(
144
+ session=session,
145
+ inputs=inputs,
146
+ output_shape_mapping=output_shape_mapping,
147
+ )
148
+ all_results = []
149
+ for _ in session.get_outputs():
150
+ all_results.append([])
151
+ for i in range(0, input_batch_size, max_batch_size):
152
+ batch_inputs = {}
153
+ reminder = 0
154
+ for name, value in inputs.items():
155
+ batched_value = value[i : i + max_batch_size]
156
+ if min_batch_size is not None:
157
+ reminder = min_batch_size - batched_value.shape[0]
158
+ if reminder > 0:
159
+ batched_value = torch.cat(
160
+ (
161
+ batched_value,
162
+ torch.zeros(
163
+ (reminder,) + batched_value.shape[1:],
164
+ dtype=batched_value.dtype,
165
+ device=batched_value.device,
166
+ ),
167
+ ),
168
+ dim=0,
169
+ )
170
+ batched_value = batched_value.contiguous()
171
+ batch_inputs[name] = batched_value
172
+ batch_output_shape_mapping = None
173
+ if output_shape_mapping:
174
+ batch_output_shape_mapping = {}
175
+ for name, shape in output_shape_mapping.items():
176
+ batch_output_shape_mapping[name] = (max_batch_size,) + shape[1:]
177
+ batch_results = run_session_via_iobinding(
178
+ session=session,
179
+ inputs=batch_inputs,
180
+ output_shape_mapping=batch_output_shape_mapping,
181
+ )
182
+ if reminder > 0:
183
+ batch_results = [r[:-reminder] for r in batch_results]
184
+ for partial_result, all_result_element in zip(batch_results, all_results):
185
+ all_result_element.append(partial_result)
186
+ return [torch.cat(e, dim=0).contiguous() for e in all_results]
187
+
188
+
189
+ def run_session_via_iobinding(
190
+ session: onnxruntime.InferenceSession,
191
+ inputs: Dict[str, torch.Tensor],
192
+ output_shape_mapping: Optional[Dict[str, tuple]] = None,
193
+ ) -> List[torch.Tensor]:
194
+ inputs = auto_cast_session_inputs(
195
+ session=session,
196
+ inputs=inputs,
197
+ )
198
+ device = get_input_device(inputs=inputs)
199
+ if device.type != "cuda":
200
+ inputs_np = {name: value.cpu().numpy() for name, value in inputs.items()}
201
+ results = session.run(None, inputs_np)
202
+ return [torch.from_numpy(element).to(device=device) for element in results]
203
+ try:
204
+ import pycuda.driver as cuda
205
+
206
+ from inference_models.models.common.cuda import use_primary_cuda_context
207
+ except ImportError as import_error:
208
+ raise MissingDependencyError(
209
+ message="TODO", help_url="https://todo"
210
+ ) from import_error
211
+ cuda.init()
212
+ cuda_device = cuda.Device(device.index or 0)
213
+ with use_primary_cuda_context(cuda_device=cuda_device):
214
+ if output_shape_mapping is None:
215
+ output_shape_mapping = {}
216
+ binding = session.io_binding()
217
+ pre_allocated_outputs: List[Optional[torch.Tensor]] = []
218
+ some_outputs_dynamically_allocated = False
219
+ for output in session.get_outputs():
220
+ if is_tensor_shape_dynamic(output.shape):
221
+ if output.name in output_shape_mapping:
222
+ torch_output_type = ort_tensor_type_to_torch_tensor_type(
223
+ output.type
224
+ )
225
+ pre_allocated_output = torch.empty(
226
+ output_shape_mapping[output.name],
227
+ dtype=torch_output_type,
228
+ device=device,
229
+ )
230
+ binding.bind_output(
231
+ name=output.name,
232
+ device_type="cuda",
233
+ device_id=device.index or 0,
234
+ element_type=torch_tensor_type_to_onnx_type(torch_output_type),
235
+ shape=tuple(pre_allocated_output.shape),
236
+ buffer_ptr=pre_allocated_output.data_ptr(),
237
+ )
238
+ pre_allocated_outputs.append(pre_allocated_output)
239
+ else:
240
+ binding.bind_output(
241
+ name=output.name,
242
+ device_type="cuda",
243
+ device_id=device.index or 0,
244
+ )
245
+ some_outputs_dynamically_allocated = True
246
+ pre_allocated_outputs.append(None)
247
+ else:
248
+ torch_output_type = ort_tensor_type_to_torch_tensor_type(output.type)
249
+ pre_allocated_output = torch.empty(
250
+ output.shape,
251
+ dtype=torch_output_type,
252
+ device=device,
253
+ )
254
+ binding.bind_output(
255
+ name=output.name,
256
+ device_type="cuda",
257
+ device_id=device.index or 0,
258
+ element_type=torch_tensor_type_to_onnx_type(torch_output_type),
259
+ shape=tuple(pre_allocated_output.shape),
260
+ buffer_ptr=pre_allocated_output.data_ptr(),
261
+ )
262
+ pre_allocated_outputs.append(pre_allocated_output)
263
+ for ort_input in session.get_inputs():
264
+ input_tensor = inputs[ort_input.name].contiguous()
265
+ input_type = torch_tensor_type_to_onnx_type(tensor_dtype=input_tensor.dtype)
266
+ binding.bind_input(
267
+ name=ort_input.name,
268
+ device_type=input_tensor.device.type,
269
+ device_id=input_tensor.device.index or 0,
270
+ element_type=input_type,
271
+ shape=input_tensor.shape,
272
+ buffer_ptr=input_tensor.data_ptr(),
273
+ )
274
+ binding.synchronize_inputs()
275
+ session.run_with_iobinding(binding)
276
+ if not some_outputs_dynamically_allocated:
277
+ return pre_allocated_outputs
278
+ bound_outputs = binding.get_outputs()
279
+ result = []
280
+ for pre_allocated_output, bound_output in zip(
281
+ pre_allocated_outputs, bound_outputs
282
+ ):
283
+ if pre_allocated_output is not None:
284
+ result.append(pre_allocated_output)
285
+ continue
286
+ # This is added for the sake of true compatibility with older builds of onnxruntime
287
+ # which do not support zero-copy OrtValue -> torch.Tensor thanks top dlpack
288
+ if not hasattr(bound_output._ortvalue, "to_dlpack"):
289
+ # slower but needed :(
290
+ out_tensor = torch.from_numpy(bound_output._ortvalue.numpy()).to(device)
291
+ else:
292
+ dlpack_tensor = bound_output._ortvalue.to_dlpack()
293
+ out_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor)
294
+ result.append(out_tensor)
295
+ return result
296
+
297
+
298
+ def auto_cast_session_inputs(
299
+ session: onnxruntime.InferenceSession, inputs: Dict[str, torch.Tensor]
300
+ ) -> Dict[str, torch.Tensor]:
301
+ for ort_input in session.get_inputs():
302
+ expected_type = ort_tensor_type_to_torch_tensor_type(ort_input.type)
303
+ if ort_input.name not in inputs:
304
+ raise ModelRuntimeError(
305
+ message="While performing forward pass through the model, library bug was discovered - "
306
+ f"required model input named '{ort_input.name}' is missing. Submit "
307
+ f"issue to help us solving this problem: https://github.com/roboflow/inference/issues",
308
+ help_url="https://todo",
309
+ )
310
+ actual_type = inputs[ort_input.name].dtype
311
+ if actual_type == expected_type:
312
+ continue
313
+ if not can_model_input_be_casted(source=actual_type, target=expected_type):
314
+ raise ModelRuntimeError(
315
+ message="While performing forward pass through the model, library bug was discovered - "
316
+ f"model requires the input type to be {expected_type}, but the actual input type is {actual_type} - "
317
+ f"this is a bug in model implementation. Submit issue to help us solving this problem: "
318
+ f"https://github.com/roboflow/inference/issues",
319
+ help_url="https://todo",
320
+ )
321
+ inputs[ort_input.name] = inputs[ort_input.name].to(dtype=expected_type)
322
+ return inputs
323
+
324
+
325
+ def torch_tensor_type_to_onnx_type(tensor_dtype: torch.dtype) -> Union[np.dtype, int]:
326
+ if tensor_dtype not in TORCH_TYPES_MAPPING:
327
+ raise ModelRuntimeError(
328
+ message=f"While performing forward pass through the model, library discovered tensor of type {tensor_dtype} "
329
+ f"which needs to be passed to onnxruntime session. Conversion of this type is currently not "
330
+ f"supported in inference. At the moment you shall assume your model incompatible with the library. "
331
+ f"To change that state - please submit new issue: https://github.com/roboflow/inference/issues",
332
+ help_url="https://todo",
333
+ )
334
+ return TORCH_TYPES_MAPPING[tensor_dtype]
335
+
336
+
337
+ def ort_tensor_type_to_torch_tensor_type(ort_dtype: str) -> torch.dtype:
338
+ if ort_dtype not in ORT_TYPES_TO_TORCH_TYPES_MAPPING:
339
+ raise ModelRuntimeError(
340
+ message=f"While performing forward pass through the model, library discovered ORT tensor of type {ort_dtype} "
341
+ f"which needs to be casted into torch.Tensor. Conversion of this type is currently not "
342
+ f"supported in inference. At the moment you shall assume your model incompatible with the library. "
343
+ f"To change that state - please submit new issue: https://github.com/roboflow/inference/issues",
344
+ help_url="https://todo",
345
+ )
346
+ return ORT_TYPES_TO_TORCH_TYPES_MAPPING[ort_dtype]
347
+
348
+
349
+ def is_tensor_shape_dynamic(shape: tuple) -> bool:
350
+ return any(isinstance(dim, str) for dim in shape)
351
+
352
+
353
+ def can_model_input_be_casted(source: torch.dtype, target: torch.dtype) -> bool:
354
+ if source not in MODEL_INPUT_CASTING:
355
+ return False
356
+ return target in MODEL_INPUT_CASTING[source]
357
+
358
+
359
+ def get_input_device(inputs: Dict[str, torch.Tensor]) -> torch.device:
360
+ device = None
361
+ for input_name, input_tensor in inputs.items():
362
+ if device is None:
363
+ device = input_tensor.device
364
+ elif input_tensor.device != device:
365
+ raise ModelRuntimeError(
366
+ message="While performing forward pass through the model, library discovered the input tensor which is "
367
+ f"wrongly allocated on a different device that rest of the inputs - input named '{input_name}' "
368
+ f"is allocated on {input_tensor.device}, whereas rest of the inputs are allocated on {device}. "
369
+ f"This is a bug in model implementation. To help us fixing that, please submit new issue: "
370
+ f"https://github.com/roboflow/inference/issues",
371
+ help_url="https://todo",
372
+ )
373
+ if device is None:
374
+ raise ModelRuntimeError(
375
+ message="No inputs detected for the model. Raise new issue to help us fixing the problem: "
376
+ "https://github.com/roboflow/inference/issues",
377
+ help_url="https://todo",
378
+ )
379
+ return device
File without changes