inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195) hide show
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,227 @@
1
+ import threading
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from inference_models import InstanceDetections, InstanceSegmentationModel
8
+ from inference_models.configuration import DEFAULT_DEVICE
9
+ from inference_models.entities import ColorFormat
10
+ from inference_models.errors import (
11
+ CorruptedModelPackageError,
12
+ MissingDependencyError,
13
+ ModelRuntimeError,
14
+ )
15
+ from inference_models.models.common.cuda import (
16
+ use_cuda_context,
17
+ use_primary_cuda_context,
18
+ )
19
+ from inference_models.models.common.model_packages import get_model_package_contents
20
+ from inference_models.models.common.roboflow.model_packages import (
21
+ InferenceConfig,
22
+ PreProcessingMetadata,
23
+ ResizeMode,
24
+ TRTConfig,
25
+ parse_class_names_file,
26
+ parse_inference_config,
27
+ parse_trt_config,
28
+ )
29
+ from inference_models.models.common.roboflow.pre_processing import (
30
+ pre_process_network_input,
31
+ )
32
+ from inference_models.models.common.trt import (
33
+ get_engine_inputs_and_outputs,
34
+ infer_from_trt_engine,
35
+ load_model,
36
+ )
37
+ from inference_models.models.rfdetr.class_remapping import (
38
+ ClassesReMapping,
39
+ prepare_class_remapping,
40
+ )
41
+ from inference_models.models.rfdetr.common import (
42
+ post_process_instance_segmentation_results,
43
+ )
44
+
45
# TensorRT is an optional dependency: when it cannot be imported, the module
# fails to load with a library-specific error that explains how to install it.
try:
    import tensorrt as trt
except ImportError as import_error:
    raise MissingDependencyError(
        message=(
            "Could not import RFDetr model with TRT backend - this error means that some additional dependencies "
            "are not installed in the environment. If you run the `inference-models` library directly in your Python "
            "program, make sure the following extras of the package are installed: `trt10` - installation can only "
            "succeed for Linux and Windows machines with Cuda 12 installed. Jetson devices, should have TRT 10.x "
            "installed for all builds with Jetpack 6. "
            "If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
            "You can also contact Roboflow to get support."
        ),
        help_url="https://todo",
    ) from import_error
58
+
59
# `pycuda` is used by this module to initialise CUDA, select the device and
# manage the CUDA context the TRT engine runs in. It is optional, so a missing
# installation is reported with an explanatory, library-specific error instead
# of a placeholder message.
try:
    import pycuda.driver as cuda
except ImportError as import_error:
    raise MissingDependencyError(
        message=(
            "Could not import RFDetr model with TRT backend - `pycuda` package, which is required to manage "
            "CUDA context for TRT engine, is not installed in the environment. If you run the `inference-models` "
            "library directly in your Python program, make sure `pycuda` is installed alongside TensorRT. "
            "If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
            "You can also contact Roboflow to get support."
        ),
        help_url="https://todo",
    ) from import_error
65
+
66
+
67
class RFDetrForInstanceSegmentationTRT(
    InstanceSegmentationModel[
        torch.Tensor,
        PreProcessingMetadata,
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]
):
    """RFDetr instance segmentation model executed through a TensorRT engine.

    Instances are normally created with `from_pretrained(...)`, which reads the
    model package, loads the engine inside the primary CUDA context of the
    selected device and validates the number of engine inputs and outputs.
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        engine_host_code_allowed: bool = False,
        **kwargs,
    ) -> "RFDetrForInstanceSegmentationTRT":
        """Build the model from a local model package.

        Args:
            model_name_or_path: Model package directory; it must provide
                `class_names.txt`, `inference_config.json`, `trt_config.json`
                and `engine.plan`.
            device: Device to run the engine on - must be of type `cuda`.
            engine_host_code_allowed: Forwarded to `load_model(...)`.
            **kwargs: Accepted for interface compatibility and not used here.

        Returns:
            Initialised `RFDetrForInstanceSegmentationTRT` instance.

        Raises:
            ModelRuntimeError: When `device` is not a CUDA device.
            CorruptedModelPackageError: When the engine does not expose exactly
                one input and exactly three outputs.
        """
        if device.type != "cuda":
            # f-string is required here so the actual device is reported
            # instead of the literal `{device}` placeholder.
            raise ModelRuntimeError(
                message=f"TRT engine only runs on CUDA device - {device} device detected.",
                help_url="https://todo",
            )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "trt_config.json",
                "engine.plan",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        # Class re-mapping is only prepared when the inference config declares
        # operations on class names; otherwise predictions are used as-is.
        classes_re_mapping = None
        if inference_config.class_names_operations:
            class_names, classes_re_mapping = prepare_class_remapping(
                class_names=class_names,
                class_names_operations=inference_config.class_names_operations,
                device=device,
            )
        trt_config = parse_trt_config(
            config_path=model_package_content["trt_config.json"]
        )
        cuda.init()
        # `device.index` may be None (plain "cuda") - fall back to device 0.
        cuda_device = cuda.Device(device.index or 0)
        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
            engine = load_model(
                model_path=model_package_content["engine.plan"],
                engine_host_code_allowed=engine_host_code_allowed,
            )
            execution_context = engine.create_execution_context()
        inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
        if len(inputs) != 1:
            raise CorruptedModelPackageError(
                message=f"Implementation assume single model input, found: {len(inputs)}.",
                help_url="https://todo",
            )
        if len(outputs) != 3:
            raise CorruptedModelPackageError(
                message=f"Implementation assume 3 model outputs, found: {len(outputs)}.",
                help_url="https://todo",
            )
        return cls(
            engine=engine,
            input_name=inputs[0],
            output_names=outputs,
            class_names=class_names,
            classes_re_mapping=classes_re_mapping,
            inference_config=inference_config,
            trt_config=trt_config,
            device=device,
            cuda_context=cuda_context,
            execution_context=execution_context,
        )

    def __init__(
        self,
        engine: trt.ICudaEngine,
        input_name: str,
        output_names: List[str],
        class_names: List[str],
        classes_re_mapping: Optional[ClassesReMapping],
        inference_config: InferenceConfig,
        trt_config: TRTConfig,
        device: torch.device,
        cuda_context: cuda.Context,
        execution_context: trt.IExecutionContext,
    ):
        """Store already-loaded engine, configuration and CUDA handles.

        Prefer `from_pretrained(...)`, which prepares all of these arguments
        from a model package.
        """
        self._engine = engine
        self._input_name = input_name
        self._output_names = output_names
        self._inference_config = inference_config
        self._class_names = class_names
        self._classes_re_mapping = classes_re_mapping
        self._device = device
        self._cuda_context = cuda_context
        self._execution_context = execution_context
        self._trt_config = trt_config
        # Guards the engine call in `forward(...)` so that only one thread
        # at a time uses the execution context.
        self._lock = threading.Lock()

    @property
    def class_names(self) -> List[str]:
        """Class names of the model (after re-mapping, when one was prepared)."""
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
        """Convert input images into the network input.

        Delegates to `pre_process_network_input(...)` using the pre-processing
        and network input settings from the inference config.

        Args:
            images: Single image / batch of images, as tensors or numpy arrays.
            input_color_format: Optional color format of `images`.
            image_size: Optional size forwarded as `image_size_wh`.
            **kwargs: Accepted for interface compatibility and not used here.

        Returns:
            Pre-processed images tensor and per-image pre-processing metadata.
        """
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )

    def forward(
        self, pre_processed_images: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the TRT engine on pre-processed images.

        The call is serialised with an instance-level lock and executed with
        the CUDA context captured at load time activated.

        Returns:
            Three engine outputs, in the order reported by the engine.
        """
        with self._lock, use_cuda_context(context=self._cuda_context):
            detections, labels, masks = infer_from_trt_engine(
                pre_processed_images=pre_processed_images,
                trt_config=self._trt_config,
                engine=self._engine,
                context=self._execution_context,
                device=self._device,
                input_name=self._input_name,
                outputs=self._output_names,
            )
            return detections, labels, masks

    def post_process(
        self,
        model_results: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        pre_processing_meta: List[PreProcessingMetadata],
        threshold: float = 0.5,
        **kwargs,
    ) -> List[InstanceDetections]:
        """Turn raw engine outputs into per-image instance detections.

        Args:
            model_results: Output of `forward(...)`, unpacked as
                (bboxes, logits, masks).
            pre_processing_meta: Metadata returned by `pre_process(...)`.
            threshold: Confidence threshold forwarded to the shared RFDetr
                post-processing routine.
            **kwargs: Accepted for interface compatibility and not used here.

        Returns:
            List of `InstanceDetections` produced by
            `post_process_instance_segmentation_results(...)`.
        """
        bboxes, logits, masks = model_results
        return post_process_instance_segmentation_results(
            bboxes=bboxes,
            logits=logits,
            masks=masks,
            pre_processing_meta=pre_processing_meta,
            threshold=threshold,
            classes_re_mapping=self._classes_re_mapping,
        )
@@ -0,0 +1,244 @@
1
+ import threading
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from inference_models import Detections, ObjectDetectionModel
8
+ from inference_models.configuration import DEFAULT_DEVICE
9
+ from inference_models.entities import ColorFormat
10
+ from inference_models.errors import (
11
+ EnvironmentConfigurationError,
12
+ MissingDependencyError,
13
+ )
14
+ from inference_models.models.common.model_packages import get_model_package_contents
15
+ from inference_models.models.common.onnx import (
16
+ run_session_with_batch_size_limit,
17
+ set_execution_provider_defaults,
18
+ )
19
+ from inference_models.models.common.roboflow.model_packages import (
20
+ InferenceConfig,
21
+ PreProcessingMetadata,
22
+ ResizeMode,
23
+ parse_class_names_file,
24
+ parse_inference_config,
25
+ )
26
+ from inference_models.models.common.roboflow.post_processing import (
27
+ rescale_image_detections,
28
+ )
29
+ from inference_models.models.common.roboflow.pre_processing import (
30
+ pre_process_network_input,
31
+ )
32
+ from inference_models.models.rfdetr.class_remapping import (
33
+ ClassesReMapping,
34
+ prepare_class_remapping,
35
+ )
36
+ from inference_models.utils.onnx_introspection import (
37
+ get_selected_onnx_execution_providers,
38
+ )
39
+
40
# onnxruntime is an optional dependency: when it cannot be imported, the module
# fails to load with a library-specific error listing the extras that provide it.
# The message names the RFDetr model (it previously referred to YOLOv8, a
# copy-paste leftover that pointed users at the wrong model).
try:
    import onnxruntime
except ImportError as import_error:
    raise MissingDependencyError(
        message=f"Could not import RFDetr model with ONNX backend - this error means that some additional dependencies "
        f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
        f"program, make sure the following extras of the package are installed: \n"
        f"\t* `onnx-cpu` - when you wish to use library with CPU support only\n"
        f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
        f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
        f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
        f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
        f"You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error
55
+
56
+
57
class RFDetrForObjectDetectionONNX(
    (
        ObjectDetectionModel[
            torch.Tensor, PreProcessingMetadata, Tuple[torch.Tensor, torch.Tensor]
        ]
    )
):
    """RFDetr object detection model executed through an ONNX Runtime session.

    Instances are normally created with `from_pretrained(...)`, which resolves
    execution providers, reads the model package and opens the ONNX session.
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
        default_onnx_trt_options: bool = True,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> "RFDetrForObjectDetectionONNX":
        """Build the model from a local model package.

        Args:
            model_name_or_path: Model package directory; it must provide
                `class_names.txt`, `inference_config.json` and `weights.onnx`.
            onnx_execution_providers: Execution providers for the session. When
                None, they are taken from
                `get_selected_onnx_execution_providers()`.
            default_onnx_trt_options: Forwarded to
                `set_execution_provider_defaults(...)`.
            device: Device used for pre-/post-processing tensors and forwarded
                to the execution provider defaults.
            **kwargs: Accepted for interface compatibility and not used here.

        Returns:
            Initialised `RFDetrForObjectDetectionONNX` instance.

        Raises:
            EnvironmentConfigurationError: When no execution provider is given
                explicitly nor resolved from the environment.
        """
        if onnx_execution_providers is None:
            onnx_execution_providers = get_selected_onnx_execution_providers()
        if not onnx_execution_providers:
            raise EnvironmentConfigurationError(
                message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
                f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
                f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
                f"contact the platform support.",
                help_url="https://todo",
            )
        onnx_execution_providers = set_execution_provider_defaults(
            providers=onnx_execution_providers,
            model_package_path=model_name_or_path,
            device=device,
            default_onnx_trt_options=default_onnx_trt_options,
        )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "weights.onnx",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        # Class re-mapping is only prepared when the inference config declares
        # operations on class names; otherwise predictions are used as-is.
        classes_re_mapping = None
        if inference_config.class_names_operations:
            class_names, classes_re_mapping = prepare_class_remapping(
                class_names=class_names,
                class_names_operations=inference_config.class_names_operations,
                device=device,
            )
        session = onnxruntime.InferenceSession(
            path_or_bytes=model_package_content["weights.onnx"],
            providers=onnx_execution_providers,
        )
        # Only the first session input is used by this implementation.
        model_input = session.get_inputs()[0]
        input_batch_size = model_input.shape[0]
        # A string in the batch dimension is a symbolic (dynamic) axis name,
        # which means there is no fixed batch size to enforce.
        if isinstance(input_batch_size, str):
            input_batch_size = None
        return cls(
            session=session,
            input_name=model_input.name,
            class_names=class_names,
            classes_re_mapping=classes_re_mapping,
            inference_config=inference_config,
            device=device,
            input_batch_size=input_batch_size,
        )

    def __init__(
        self,
        session: onnxruntime.InferenceSession,
        input_name: str,
        class_names: List[str],
        classes_re_mapping: Optional[ClassesReMapping],
        inference_config: InferenceConfig,
        device: torch.device,
        input_batch_size: Optional[int],
    ):
        """Store the already-opened session and the model configuration.

        Args:
            input_batch_size: Fixed batch size of the model input, or None when
                the batch dimension is dynamic. With a dynamic batch, the upper
                limit comes from
                `inference_config.forward_pass.max_dynamic_batch_size`.
        """
        self._session = session
        self._input_name = input_name
        self._inference_config = inference_config
        self._class_names = class_names
        self._classes_re_mapping = classes_re_mapping
        self._device = device
        self._min_batch_size = input_batch_size
        self._max_batch_size = (
            input_batch_size
            if input_batch_size is not None
            else inference_config.forward_pass.max_dynamic_batch_size
        )
        # Guards the session call in `forward(...)` so that only one thread
        # at a time runs inference through this instance.
        self._session_thread_lock = threading.Lock()

    @property
    def class_names(self) -> List[str]:
        """Class names of the model (after re-mapping, when one was prepared)."""
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
        """Convert input images into the network input.

        Delegates to `pre_process_network_input(...)` using the pre-processing
        and network input settings from the inference config.

        Returns:
            Pre-processed images tensor and per-image pre-processing metadata.
        """
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
        )

    def forward(
        self, pre_processed_images: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the ONNX session on pre-processed images.

        The call is serialised with an instance-level lock and respects the
        min / max batch size limits established at construction time.

        Returns:
            Two session outputs, unpacked as (bboxes, logits).
        """
        with self._session_thread_lock:
            bboxes, logits = run_session_with_batch_size_limit(
                session=self._session,
                inputs={self._input_name: pre_processed_images},
                min_batch_size=self._min_batch_size,
                max_batch_size=self._max_batch_size,
            )
            return bboxes, logits

    def post_process(
        self,
        model_results: Tuple[torch.Tensor, torch.Tensor],
        pre_processing_meta: List[PreProcessingMetadata],
        threshold: float = 0.5,
        **kwargs,
    ) -> List[Detections]:
        """Turn raw session outputs into per-image detections.

        For every image: the best class per query is selected from sigmoid
        scores, predictions with confidence not greater than `threshold` are
        dropped, the rest is sorted by descending confidence, optionally
        re-mapped to the remaining classes, converted from (cx, cy, w, h) to
        (x_min, y_min, x_max, y_max), scaled to the inference size and finally
        rescaled with `rescale_image_detections(...)`.

        Args:
            model_results: Output of `forward(...)`, unpacked as (bboxes, logits).
            pre_processing_meta: Metadata returned by `pre_process(...)`.
            threshold: Minimal (exclusive) confidence of kept predictions.
            **kwargs: Accepted for interface compatibility and not used here.

        Returns:
            One `Detections` object per input image.
        """
        bboxes, logits = model_results
        logits_sigmoid = torch.nn.functional.sigmoid(logits)
        results = []
        for image_bboxes, image_logits, image_meta in zip(
            bboxes, logits_sigmoid, pre_processing_meta
        ):
            confidence, top_classes = image_logits.max(dim=1)
            confidence_mask = confidence > threshold
            confidence = confidence[confidence_mask]
            top_classes = top_classes[confidence_mask]
            selected_boxes = image_bboxes[confidence_mask]
            confidence, sorted_indices = torch.sort(confidence, descending=True)
            top_classes = top_classes[sorted_indices]
            selected_boxes = selected_boxes[sorted_indices]
            if self._classes_re_mapping is not None:
                # Drop predictions of classes removed by class names operations
                # and translate the remaining class ids.
                remapping_mask = torch.isin(
                    top_classes, self._classes_re_mapping.remaining_class_ids
                )
                top_classes = self._classes_re_mapping.class_mapping[
                    top_classes[remapping_mask]
                ]
                selected_boxes = selected_boxes[remapping_mask]
                confidence = confidence[remapping_mask]
            cxcy = selected_boxes[:, :2]
            wh = selected_boxes[:, 2:]
            xy_min = cxcy - 0.5 * wh
            xy_max = cxcy + 0.5 * wh
            selected_boxes_xyxy_pct = torch.cat([xy_min, xy_max], dim=-1)
            # Columns are (x_min, y_min, x_max, y_max), so x coordinates must be
            # scaled by width and y coordinates by height. The previous
            # (height, width, height, width) order was only correct for square
            # inference sizes.
            inference_size_whwh = torch.tensor(
                [
                    image_meta.inference_size.width,
                    image_meta.inference_size.height,
                    image_meta.inference_size.width,
                    image_meta.inference_size.height,
                ],
                device=self._device,
            )
            selected_boxes_xyxy = selected_boxes_xyxy_pct * inference_size_whwh
            selected_boxes_xyxy = rescale_image_detections(
                image_detections=selected_boxes_xyxy,
                image_metadata=image_meta,
            )
            detections = Detections(
                xyxy=selected_boxes_xyxy.round().int(),
                confidence=confidence,
                class_id=top_classes.int(),
            )
            results.append(detections)
        return results