inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,365 @@
+ from threading import Lock
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import (
+     ClassificationModel,
+     ClassificationPrediction,
+     MultiLabelClassificationModel,
+     MultiLabelClassificationPrediction,
+ )
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     CorruptedModelPackageError,
+     MissingDependencyError,
+     ModelRuntimeError,
+ )
+ from inference_models.models.base.types import PreprocessedInputs
+ from inference_models.models.common.cuda import (
+     use_cuda_context,
+     use_primary_cuda_context,
+ )
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     ResizeMode,
+     TRTConfig,
+     parse_class_names_file,
+     parse_inference_config,
+     parse_trt_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.models.common.trt import (
+     get_engine_inputs_and_outputs,
+     infer_from_trt_engine,
+     load_model,
+ )
+
+ try:
+     import tensorrt as trt
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import ViT model with TRT backend - this error means that some additional dependencies "
+         f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+         f"program, make sure the following extras of the package are installed: `trt10` - installation can only "
+         f"succeed for Linux and Windows machines with Cuda 12 installed. Jetson devices should have TRT 10.x "
+         f"installed for all builds with Jetpack 6. "
+         f"If you see this error using Roboflow infrastructure, make sure the service you use supports the model. "
+         f"You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+ try:
+     import pycuda.driver as cuda
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message="TODO", help_url="https://todo"
+     ) from import_error
+
+
+ class VITForClassificationTRT(ClassificationModel[torch.Tensor, torch.Tensor]):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         engine_host_code_allowed: bool = False,
+         **kwargs,
+     ) -> "VITForClassificationTRT":
+         if device.type != "cuda":
+             raise ModelRuntimeError(
+                 message=f"TRT engine only runs on CUDA device - {device} device detected.",
+                 help_url="https://todo",
+             )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "trt_config.json",
+                 "engine.plan",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "softmax":
+             raise CorruptedModelPackageError(
+                 message="Expected Softmax to be the post-processing",
+                 help_url="https://todo",
+             )
+         trt_config = parse_trt_config(
+             config_path=model_package_content["trt_config.json"]
+         )
+         cuda.init()
+         cuda_device = cuda.Device(device.index or 0)
+         with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
+             engine = load_model(
+                 model_path=model_package_content["engine.plan"],
+                 engine_host_code_allowed=engine_host_code_allowed,
+             )
+             execution_context = engine.create_execution_context()
+             inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
+             if len(inputs) != 1:
+                 raise CorruptedModelPackageError(
+                     message=f"Implementation assumes a single model input, found: {len(inputs)}.",
+                     help_url="https://todo",
+                 )
+             if len(outputs) != 1:
+                 raise CorruptedModelPackageError(
+                     message=f"Implementation assumes a single model output, found: {len(outputs)}.",
+                     help_url="https://todo",
+                 )
+             return cls(
+                 engine=engine,
+                 input_name=inputs[0],
+                 output_name=outputs[0],
+                 class_names=class_names,
+                 inference_config=inference_config,
+                 trt_config=trt_config,
+                 device=device,
+                 cuda_context=cuda_context,
+                 execution_context=execution_context,
+             )
+
+     def __init__(
+         self,
+         engine: trt.ICudaEngine,
+         input_name: str,
+         output_name: str,
+         class_names: List[str],
+         inference_config: InferenceConfig,
+         trt_config: TRTConfig,
+         device: torch.device,
+         cuda_context: cuda.Context,
+         execution_context: trt.IExecutionContext,
+     ):
+         self._engine = engine
+         self._input_name = input_name
+         self._output_names = [output_name]
+         self._class_names = class_names
+         self._inference_config = inference_config
+         self._trt_config = trt_config
+         self._device = device
+         self._cuda_context = cuda_context
+         self._execution_context = execution_context
+         self._lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )[0]
+
+     def forward(
+         self, pre_processed_images: PreprocessedInputs, **kwargs
+     ) -> torch.Tensor:
+         with self._lock:
+             with use_cuda_context(context=self._cuda_context):
+                 return infer_from_trt_engine(
+                     pre_processed_images=pre_processed_images,
+                     trt_config=self._trt_config,
+                     engine=self._engine,
+                     context=self._execution_context,
+                     device=self._device,
+                     input_name=self._input_name,
+                     outputs=self._output_names,
+                 )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         **kwargs,
+     ) -> ClassificationPrediction:
+         if self._inference_config.post_processing.fused:
+             confidence = model_results
+         else:
+             confidence = torch.nn.functional.softmax(model_results, dim=-1)
+         return ClassificationPrediction(
+             class_id=confidence.argmax(dim=-1),
+             confidence=confidence,
+         )
+
+
+ class VITForMultiLabelClassificationTRT(
+     MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         engine_host_code_allowed: bool = False,
+         **kwargs,
+     ) -> "VITForMultiLabelClassificationTRT":
+         if device.type != "cuda":
+             raise ModelRuntimeError(
+                 message=f"TRT engine only runs on CUDA device - {device} device detected.",
+                 help_url="https://todo",
+             )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "trt_config.json",
+                 "engine.plan",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "sigmoid":
+             raise CorruptedModelPackageError(
+                 message="Expected sigmoid to be the post-processing",
+                 help_url="https://todo",
+             )
+         trt_config = parse_trt_config(
+             config_path=model_package_content["trt_config.json"]
+         )
+         cuda.init()
+         cuda_device = cuda.Device(device.index or 0)
+         with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
+             engine = load_model(
+                 model_path=model_package_content["engine.plan"],
+                 engine_host_code_allowed=engine_host_code_allowed,
+             )
+             execution_context = engine.create_execution_context()
+             inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
+             if len(inputs) != 1:
+                 raise CorruptedModelPackageError(
+                     message=f"Implementation assumes a single model input, found: {len(inputs)}.",
+                     help_url="https://todo",
+                 )
+             if len(outputs) != 1:
+                 raise CorruptedModelPackageError(
+                     message=f"Implementation assumes a single model output, found: {len(outputs)}.",
+                     help_url="https://todo",
+                 )
+             return cls(
+                 engine=engine,
+                 input_name=inputs[0],
+                 output_name=outputs[0],
+                 class_names=class_names,
+                 inference_config=inference_config,
+                 trt_config=trt_config,
+                 device=device,
+                 cuda_context=cuda_context,
+                 execution_context=execution_context,
+             )
+
+     def __init__(
+         self,
+         engine: trt.ICudaEngine,
+         input_name: str,
+         output_name: str,
+         class_names: List[str],
+         inference_config: InferenceConfig,
+         trt_config: TRTConfig,
+         device: torch.device,
+         cuda_context: cuda.Context,
+         execution_context: trt.IExecutionContext,
+     ):
+         self._engine = engine
+         self._input_name = input_name
+         self._output_names = [output_name]
+         self._class_names = class_names
+         self._inference_config = inference_config
+         self._trt_config = trt_config
+         self._device = device
+         self._cuda_context = cuda_context
+         self._execution_context = execution_context
+         self._lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )[0]
+
+     def forward(
+         self, pre_processed_images: PreprocessedInputs, **kwargs
+     ) -> torch.Tensor:
+         with self._lock:
+             with use_cuda_context(context=self._cuda_context):
+                 return infer_from_trt_engine(
+                     pre_processed_images=pre_processed_images,
+                     trt_config=self._trt_config,
+                     engine=self._engine,
+                     context=self._execution_context,
+                     device=self._device,
+                     input_name=self._input_name,
+                     outputs=self._output_names,
+                 )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         confidence: float = 0.5,
+         **kwargs,
+     ) -> List[MultiLabelClassificationPrediction]:
+         if self._inference_config.post_processing.fused:
+             model_results = model_results
+         else:
+             model_results = torch.nn.functional.sigmoid(model_results)
+         results = []
+         for batch_element_confidence in model_results:
+             predicted_classes = torch.argwhere(
+                 batch_element_confidence >= confidence
+             ).squeeze(dim=-1)
+             results.append(
+                 MultiLabelClassificationPrediction(
+                     class_ids=predicted_classes,
+                     confidence=batch_element_confidence,
+                 )
+             )
+         return results
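
For orientation, both TRT classes above expose the same four-step flow: from_pretrained, pre_process, forward, post_process. Below is a minimal usage sketch derived only from the code in this file; the package directory path and the input frame are hypothetical, and a CUDA device is required by from_pretrained.

import numpy as np
import torch

# Hypothetical local package directory containing class_names.txt,
# inference_config.json, trt_config.json and engine.plan.
MODEL_PACKAGE_DIR = "/path/to/vit-classification-trt-package"

model = VITForClassificationTRT.from_pretrained(
    MODEL_PACKAGE_DIR,
    device=torch.device("cuda:0"),
)
frame = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC image, e.g. decoded with OpenCV
pre_processed = model.pre_process(frame)         # batched tensor on the CUDA device
logits = model.forward(pre_processed)            # raw TRT engine output
prediction = model.post_process(logits)          # ClassificationPrediction
top_class = model.class_names[int(prediction.class_id[0])]
top_score = float(prediction.confidence[0].max())
print(top_class, top_score)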
@@ -0,0 +1 @@
+ # TODO: decide if port is needed
@@ -0,0 +1,336 @@
+ from threading import Lock
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torchvision
+
+ from inference_models import InstanceDetections, InstanceSegmentationModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     EnvironmentConfigurationError,
+     MissingDependencyError,
+     ModelRuntimeError,
+ )
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.onnx import (
+     run_session_with_batch_size_limit,
+     set_execution_provider_defaults,
+ )
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     PreProcessingMetadata,
+     ResizeMode,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.post_processing import (
+     align_instance_segmentation_results,
+     crop_masks_to_boxes,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.utils.onnx_introspection import (
+     get_selected_onnx_execution_providers,
+ )
+
+ try:
+     import onnxruntime
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import YOLACT model with ONNX backend - this error means that some additional dependencies "
+         f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+         f"program, make sure the following extras of the package are installed: \n"
+         f"\t* `onnx-cpu` - when you wish to use the library with CPU support only\n"
+         f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
+         f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
+         f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
+         f"If you see this error using Roboflow infrastructure, make sure the service you use supports the model. "
+         f"You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+
+ class YOLOACTForInstanceSegmentationOnnx(
+     InstanceSegmentationModel[
+         torch.Tensor,
+         PreProcessingMetadata,
+         Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+     ]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "YOLOACTForInstanceSegmentationOnnx":
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
+                 f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
+                 f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
+                 f"contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.onnx",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["weights.onnx"],
+             providers=onnx_execution_providers,
+         )
+         input_batch_size = session.get_inputs()[0].shape[0]
+         if input_batch_size != 1:
+             raise ModelRuntimeError(
+                 message="Implementation of YOLOACTForInstanceSegmentationOnnx is adjusted to work correctly with "
+                 "onnx models accepting inputs with `batch_size=1`. It can be extended if needed, but we've "
+                 "not heard such a request so far. If you find that a valuable feature - let us know via "
+                 "https://github.com/roboflow/inference/issues"
+             )
+         input_name = session.get_inputs()[0].name
+         return cls(
+             session=session,
+             input_name=input_name,
+             class_names=class_names,
+             inference_config=inference_config,
+             device=device,
+         )
+
+     def __init__(
+         self,
+         session: onnxruntime.InferenceSession,
+         input_name: str,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+     ):
+         self._session = session
+         self._input_name = input_name
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+         self._session_thread_lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )
+
+     def forward(
+         self, pre_processed_images: torch.Tensor, **kwargs
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         with self._session_thread_lock:
+             (
+                 all_loc_data,
+                 all_conf_data,
+                 all_mask_data,
+                 all_prior_data,
+                 all_proto_data,
+             ) = ([], [], [], [], [])
+             for image in pre_processed_images:
+                 loc_data, conf_data, mask_data, prior_data, proto_data = (
+                     run_session_with_batch_size_limit(
+                         session=self._session,
+                         inputs={self._input_name: image.unsqueeze(0).contiguous()},
+                     )
+                 )
+                 all_loc_data.append(loc_data)
+                 all_conf_data.append(conf_data)
+                 all_mask_data.append(mask_data)
+                 all_prior_data.append(prior_data)
+                 all_proto_data.append(proto_data)
+             return (
+                 torch.cat(all_loc_data, dim=0),
+                 torch.cat(all_conf_data, dim=0),
+                 torch.cat(all_mask_data, dim=0),
+                 torch.stack(all_prior_data, dim=0),
+                 torch.cat(all_proto_data, dim=0),
+             )
+
+     def post_process(
+         self,
+         model_results: Tuple[
+             torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+         ],
+         pre_processing_meta: List[PreProcessingMetadata],
+         conf_thresh: float = 0.25,
+         iou_thresh: float = 0.45,
+         max_detections: int = 100,
+         class_agnostic: bool = False,
+         **kwargs,
+     ) -> List[InstanceDetections]:
+         all_loc_data, all_conf_data, all_mask_data, all_prior_data, all_proto_data = (
+             model_results
+         )
+         batch_size = all_loc_data.shape[0]
+         num_priors = all_loc_data.shape[1]
+         boxes = torch.zeros((batch_size, num_priors, 4), device=self._device)
+         for batch_element_id, (
+             batch_element_loc_data,
+             batch_element_priors,
+             image_prep_meta,
+         ) in enumerate(zip(all_loc_data, all_prior_data, pre_processing_meta)):
+             image_boxes = decode_predicted_bboxes(
+                 loc_data=batch_element_loc_data,
+                 priors=batch_element_priors,
+             )
+             inference_height, inference_width = (
+                 image_prep_meta.inference_size.height,
+                 image_prep_meta.inference_size.width,
+             )
+             scale = torch.tensor(
+                 [inference_width, inference_height, inference_width, inference_height],
+                 device=self._device,
+             )
+             image_boxes = image_boxes.mul_(scale)
+             boxes[batch_element_id, :, :] = image_boxes
+         all_conf_data = all_conf_data[:, :, 1:]  # remove background class
+         instances = torch.cat([boxes, all_conf_data, all_mask_data], dim=2)
+         nms_results = run_nms_for_instance_segmentation(
+             output=instances,
+             conf_thresh=conf_thresh,
+             iou_thresh=iou_thresh,
+             max_detections=max_detections,
+             class_agnostic=class_agnostic,
+         )
+         final_results = []
+         for image_bboxes, image_protos, image_meta in zip(
+             nms_results, all_proto_data, pre_processing_meta
+         ):
+             pre_processed_masks = image_protos @ image_bboxes[:, 6:].T
+             pre_processed_masks = 1 / (1 + torch.exp(-pre_processed_masks))
+             pre_processed_masks = torch.permute(pre_processed_masks, (2, 0, 1))
+             cropped_masks = crop_masks_to_boxes(
+                 image_bboxes[:, :4], pre_processed_masks
+             )
+             padding = (
+                 image_meta.pad_left,
+                 image_meta.pad_top,
+                 image_meta.pad_right,
+                 image_meta.pad_bottom,
+             )
+             aligned_boxes, aligned_masks = align_instance_segmentation_results(
+                 image_bboxes=image_bboxes,
+                 masks=cropped_masks,
+                 padding=padding,
+                 scale_height=image_meta.scale_height,
+                 scale_width=image_meta.scale_width,
+                 original_size=image_meta.original_size,
+                 size_after_pre_processing=image_meta.size_after_pre_processing,
+                 inference_size=image_meta.inference_size,
+                 static_crop_offset=image_meta.static_crop_offset,
+                 binarization_threshold=0.5,
+             )
+             final_results.append(
+                 InstanceDetections(
+                     xyxy=aligned_boxes[:, :4].round().int(),
+                     class_id=aligned_boxes[:, 5].int(),
+                     confidence=aligned_boxes[:, 4],
+                     mask=aligned_masks,
+                 )
+             )
+         return final_results
+
+
+ def decode_predicted_bboxes(
+     loc_data: torch.Tensor, priors: torch.Tensor
+ ) -> torch.Tensor:
+     variances = torch.tensor([0.1, 0.2], device=loc_data.device)
+     boxes = torch.cat(
+         [
+             priors[:, :2] + loc_data[:, :2] * variances[0] * priors[:, 2:],
+             priors[:, 2:] * torch.exp(loc_data[:, 2:] * variances[1]),
+         ],
+         dim=1,
+     )
+     boxes[:, :2] -= boxes[:, 2:] / 2
+     boxes[:, 2:] += boxes[:, :2]
+     return boxes
+
+
+ def run_nms_for_instance_segmentation(
+     output: torch.Tensor,
+     conf_thresh: float = 0.25,
+     iou_thresh: float = 0.45,
+     max_detections: int = 100,
+     class_agnostic: bool = False,
+ ) -> List[torch.Tensor]:
+     bs = output.shape[0]
+     boxes = output[:, :, :4]  # (N, 19248, 4)
+     scores = output[:, :, 4:-32]  # (N, 19248, num_classes)
+     masks = output[:, :, -32:]
+     results = []
+     for b in range(bs):
+         bboxes = boxes[b]  # (19248, 4)
+         class_scores = scores[b]  # (19248, 80)
+         box_masks = masks[b]
+         class_conf, class_ids = class_scores.max(1)  # (19248,), (19248,)
+         mask = class_conf > conf_thresh
+         if mask.sum() == 0:
+             results.append(torch.zeros((0, 38), device=output.device))
+             continue
+         bboxes = bboxes[mask]
+         class_conf = class_conf[mask]
+         class_ids = class_ids[mask]
+         box_masks = box_masks[mask]
+         # Class-agnostic NMS -> use dummy class ids
+         nms_class_ids = torch.zeros_like(class_ids) if class_agnostic else class_ids
+         keep = torchvision.ops.batched_nms(
+             bboxes, class_conf, nms_class_ids, iou_thresh
+         )
+         keep = keep[:max_detections]
+         detections = torch.cat(
+             [
+                 bboxes[keep],
+                 class_conf[keep].unsqueeze(1),
+                 class_ids[keep].unsqueeze(1).float(),
+                 box_masks[keep],
+             ],
+             dim=1,
+         )  # [x1, y1, x2, y2, conf, cls, 32 mask coefficients]
+         results.append(detections)
+     return results
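
As a side note on the geometry in decode_predicted_bboxes above: it is the standard SSD/YOLACT prior-box decoding, where the predicted offsets shift the prior centers and exponentially rescale the prior sizes before conversion to corner coordinates. A small self-contained sketch of the same math with illustrative values (the numbers are not taken from the package):

import torch

# One prior box as (cx, cy, w, h) in relative coordinates and one predicted offset.
priors = torch.tensor([[0.50, 0.50, 0.20, 0.20]])
loc_data = torch.tensor([[0.10, -0.20, 0.05, 0.05]])
variances = torch.tensor([0.1, 0.2])  # same constants as in the code above

centers = priors[:, :2] + loc_data[:, :2] * variances[0] * priors[:, 2:]  # shifted centers
sizes = priors[:, 2:] * torch.exp(loc_data[:, 2:] * variances[1])         # rescaled w, h
boxes = torch.cat([centers - sizes / 2, centers + sizes / 2], dim=1)      # (x1, y1, x2, y2)
print(boxes)  # still in relative coordinates; post_process multiplies by the inference size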