inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/yolov8/yolov8_key_points_detection_trt.py
@@ -0,0 +1,287 @@
+ from threading import Lock
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import Detections, KeyPoints, KeyPointsDetectionModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     CorruptedModelPackageError,
+     MissingDependencyError,
+     ModelRuntimeError,
+ )
+ from inference_models.models.common.cuda import (
+     use_cuda_context,
+     use_primary_cuda_context,
+ )
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     PreProcessingMetadata,
+     ResizeMode,
+     TRTConfig,
+     parse_class_names_file,
+     parse_inference_config,
+     parse_key_points_metadata,
+     parse_trt_config,
+ )
+ from inference_models.models.common.roboflow.post_processing import (
+     post_process_nms_fused_model_output,
+     rescale_key_points_detections,
+     run_nms_for_key_points_detection,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.models.common.trt import (
+     get_engine_inputs_and_outputs,
+     infer_from_trt_engine,
+     load_model,
+ )
+
+ try:
+     import tensorrt as trt
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message="Could not import YOLOv8 model with TRT backend - this error means that some additional "
+         "dependencies are not installed in the environment. If you run the `inference-models` library "
+         "directly in your Python program, make sure the `trt10` extra of the package is installed - "
+         "installation can only succeed on Linux and Windows machines with CUDA 12 installed. Jetson "
+         "devices should have TRT 10.x installed for all builds with Jetpack 6. "
+         "If you see this error using Roboflow infrastructure, make sure the service you use supports "
+         "the model. You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+ try:
+     import pycuda.driver as cuda
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message="TODO",
+         help_url="https://todo",
+     ) from import_error
+
+
+ class YOLOv8ForKeyPointsDetectionTRT(
+     KeyPointsDetectionModel[torch.Tensor, PreProcessingMetadata, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         engine_host_code_allowed: bool = False,
+         **kwargs,
+     ) -> "YOLOv8ForKeyPointsDetectionTRT":
+         if device.type != "cuda":
+             raise ModelRuntimeError(
+                 message=f"TRT engine only runs on CUDA device - {device} device detected.",
+                 help_url="https://todo",
+             )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "trt_config.json",
+                 "engine.plan",
+                 "keypoints_metadata.json",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "nms":
+             raise CorruptedModelPackageError(
+                 message="Expected NMS to be the post-processing type.",
+                 help_url="https://todo",
+             )
+         trt_config = parse_trt_config(
+             config_path=model_package_content["trt_config.json"]
+         )
+         parsed_key_points_metadata, skeletons = parse_key_points_metadata(
+             key_points_metadata_path=model_package_content["keypoints_metadata.json"]
+         )
+         cuda.init()
+         cuda_device = cuda.Device(device.index or 0)
+         with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
+             engine = load_model(
+                 model_path=model_package_content["engine.plan"],
+                 engine_host_code_allowed=engine_host_code_allowed,
+             )
+             execution_context = engine.create_execution_context()
+         inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
+         if len(inputs) != 1:
+             raise CorruptedModelPackageError(
+                 message=f"Implementation assumes a single model input, found: {len(inputs)}.",
+                 help_url="https://todo",
+             )
+         if len(outputs) != 1:
+             raise CorruptedModelPackageError(
+                 message=f"Implementation assumes a single model output, found: {len(outputs)}.",
+                 help_url="https://todo",
+             )
+         return cls(
+             engine=engine,
+             input_name=inputs[0],
+             output_name=outputs[0],
+             class_names=class_names,
+             skeletons=skeletons,
+             inference_config=inference_config,
+             parsed_key_points_metadata=parsed_key_points_metadata,
+             trt_config=trt_config,
+             device=device,
+             cuda_context=cuda_context,
+             execution_context=execution_context,
+         )
+
+     def __init__(
+         self,
+         engine: trt.ICudaEngine,
+         input_name: str,
+         output_name: str,
+         class_names: List[str],
+         skeletons: List[List[Tuple[int, int]]],
+         inference_config: InferenceConfig,
+         parsed_key_points_metadata: List[List[str]],
+         trt_config: TRTConfig,
+         device: torch.device,
+         cuda_context: cuda.Context,
+         execution_context: trt.IExecutionContext,
+     ):
+         self._engine = engine
+         self._input_name = input_name
+         self._output_names = [output_name]
+         self._cuda_context = cuda_context
+         self._execution_context = execution_context
+         self._class_names = class_names
+         self._skeletons = skeletons
+         self._inference_config = inference_config
+         self._parsed_key_points_metadata = parsed_key_points_metadata
+         self._trt_config = trt_config
+         self._device = device
+         self._session_thread_lock = Lock()
+         # number of key point classes defined for each instance class
+         self._key_points_classes_for_instances = torch.tensor(
+             [len(e) for e in self._parsed_key_points_metadata], device=device
+         )
+         self._key_points_slots_in_prediction = max(
+             len(e) for e in parsed_key_points_metadata
+         )
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     @property
+     def key_points_classes(self) -> List[List[str]]:
+         return self._parsed_key_points_metadata
+
+     @property
+     def skeletons(self) -> List[List[Tuple[int, int]]]:
+         return self._skeletons
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+         with self._session_thread_lock:
+             with use_cuda_context(context=self._cuda_context):
+                 return infer_from_trt_engine(
+                     pre_processed_images=pre_processed_images,
+                     trt_config=self._trt_config,
+                     engine=self._engine,
+                     context=self._execution_context,
+                     device=self._device,
+                     input_name=self._input_name,
+                     outputs=self._output_names,
+                 )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         pre_processing_meta: List[PreProcessingMetadata],
+         conf_thresh: float = 0.25,
+         iou_thresh: float = 0.45,
+         max_detections: int = 100,
+         class_agnostic: bool = False,
+         key_points_threshold: float = 0.3,
+         **kwargs,
+     ) -> Tuple[List[KeyPoints], Optional[List[Detections]]]:
+         if self._inference_config.post_processing.fused:
+             nms_results = post_process_nms_fused_model_output(
+                 output=model_results, conf_thresh=conf_thresh
+             )
+         else:
+             nms_results = run_nms_for_key_points_detection(
+                 output=model_results,
+                 num_classes=len(self._class_names),
+                 key_points_slots_in_prediction=self._key_points_slots_in_prediction,
+                 conf_thresh=conf_thresh,
+                 iou_thresh=iou_thresh,
+                 max_detections=max_detections,
+                 class_agnostic=class_agnostic,
+             )
+         rescaled_results = rescale_key_points_detections(
+             detections=nms_results,
+             images_metadata=pre_processing_meta,
+             num_classes=len(self._class_names),
+             key_points_slots_in_prediction=self._key_points_slots_in_prediction,
+         )
+         detections, all_key_points = [], []
+         for result in rescaled_results:
+             class_id = result[:, 5].int()
+             detections.append(
+                 Detections(
+                     xyxy=result[:, :4].round().int(),
+                     class_id=class_id,
+                     confidence=result[:, 4],
+                 )
+             )
+             key_points_reshaped = result[:, 6:].view(
+                 result.shape[0], self._key_points_slots_in_prediction, 3
+             )
+             xy = key_points_reshaped[:, :, :2]
+             confidence = key_points_reshaped[:, :, 2]
+             key_points_classes_for_instance_class = (
+                 (self._key_points_classes_for_instances[class_id])
+                 .unsqueeze(1)
+                 .to(device=result.device)
+             )
+             instances_class_mask = (
+                 torch.arange(self._key_points_slots_in_prediction, device=result.device)
+                 .unsqueeze(0)
+                 .repeat(result.shape[0], 1)
+                 < key_points_classes_for_instance_class
+             )
+             # zero out valid key point slots whose confidence falls below the threshold
+             confidence_mask = confidence < key_points_threshold
+             mask = instances_class_mask & confidence_mask
+             xy[mask] = 0.0
+             confidence[mask] = 0.0
+             all_key_points.append(
+                 KeyPoints(xy=xy.round().int(), class_id=class_id, confidence=confidence)
+             )
+         return all_key_points, detections
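For orientation, a minimal usage sketch of the class above. This is illustrative only and not part of the diff: the package directory, the blank test image, and the thresholds are placeholders, and it assumes a CUDA machine with the `trt10` extra installed.

import numpy as np
import torch

from inference_models.models.yolov8.yolov8_key_points_detection_trt import (
    YOLOv8ForKeyPointsDetectionTRT,
)

# hypothetical path to an already-downloaded Roboflow model package
model = YOLOv8ForKeyPointsDetectionTRT.from_pretrained(
    "/path/to/model/package",
    device=torch.device("cuda:0"),
)
image = np.zeros((640, 640, 3), dtype=np.uint8)  # stand-in for a real frame
tensor, metadata = model.pre_process(images=image)
raw_output = model.forward(pre_processed_images=tensor)
key_points, detections = model.post_process(
    model_results=raw_output,
    pre_processing_meta=metadata,
    conf_thresh=0.25,
    key_points_threshold=0.3,
)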
inference_models/models/yolov8/yolov8_object_detection_onnx.py
@@ -0,0 +1,213 @@
+ from threading import Lock
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import Detections, ObjectDetectionModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     CorruptedModelPackageError,
+     EnvironmentConfigurationError,
+     MissingDependencyError,
+ )
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.onnx import (
+     run_session_with_batch_size_limit,
+     set_execution_provider_defaults,
+ )
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     PreProcessingMetadata,
+     ResizeMode,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.post_processing import (
+     post_process_nms_fused_model_output,
+     rescale_detections,
+     run_nms_for_object_detection,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.utils.onnx_introspection import (
+     get_selected_onnx_execution_providers,
+ )
+
+ try:
+     import onnxruntime
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message="Could not import YOLOv8 model with ONNX backend - this error means that some additional "
+         "dependencies are not installed in the environment. If you run the `inference-models` library "
+         "directly in your Python program, make sure the following extras of the package are installed:\n"
+         "\t* `onnx-cpu` - when you wish to use the library with CPU support only\n"
+         "\t* `onnx-cu12` - for running on GPU with CUDA 12 installed\n"
+         "\t* `onnx-cu118` - for running on GPU with CUDA 11.8 installed\n"
+         "\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
+         "If you see this error using Roboflow infrastructure, make sure the service you use supports "
+         "the model. You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+
+ class YOLOv8ForObjectDetectionOnnx(
+     ObjectDetectionModel[torch.Tensor, PreProcessingMetadata, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "YOLOv8ForObjectDetectionOnnx":
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message="Could not initialize model - the selected backend is ONNX, which requires an "
+                 "execution provider to be specified - either explicitly in the `from_pretrained(...)` "
+                 "method or via the env variable `ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run the model "
+                 "locally - adjust your setup, otherwise contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.onnx",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "nms":
+             raise CorruptedModelPackageError(
+                 message="Expected NMS to be the post-processing type.",
+                 help_url="https://todo",
+             )
+         session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["weights.onnx"],
+             providers=onnx_execution_providers,
+         )
+         input_batch_size = session.get_inputs()[0].shape[0]
+         if isinstance(input_batch_size, str):
+             # a dynamic batch dimension is reported as a named (string) dim
+             input_batch_size = None
+         input_name = session.get_inputs()[0].name
+         return cls(
+             session=session,
+             input_name=input_name,
+             class_names=class_names,
+             inference_config=inference_config,
+             device=device,
+             input_batch_size=input_batch_size,
+         )
+
+     def __init__(
+         self,
+         session: onnxruntime.InferenceSession,
+         input_name: str,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+         input_batch_size: Optional[int],
+     ):
+         self._session = session
+         self._input_name = input_name
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+         self._min_batch_size = input_batch_size
+         self._max_batch_size = (
+             input_batch_size
+             if input_batch_size is not None
+             else inference_config.forward_pass.max_dynamic_batch_size
+         )
+         self._session_thread_lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Union[Tuple[int, int], int]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+             image_size_wh=image_size,
+         )
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+         with self._session_thread_lock:
+             return run_session_with_batch_size_limit(
+                 session=self._session,
+                 inputs={self._input_name: pre_processed_images},
+                 min_batch_size=self._min_batch_size,
+                 max_batch_size=self._max_batch_size,
+             )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         pre_processing_meta: List[PreProcessingMetadata],
+         conf_thresh: float = 0.25,
+         iou_thresh: float = 0.45,
+         max_detections: int = 100,
+         class_agnostic: bool = False,
+         **kwargs,
+     ) -> List[Detections]:
+         if self._inference_config.post_processing.fused:
+             nms_results = post_process_nms_fused_model_output(
+                 output=model_results, conf_thresh=conf_thresh
+             )
+         else:
+             nms_results = run_nms_for_object_detection(
+                 output=model_results,
+                 conf_thresh=conf_thresh,
+                 iou_thresh=iou_thresh,
+                 max_detections=max_detections,
+                 class_agnostic=class_agnostic,
+             )
+         rescaled_results = rescale_detections(
+             detections=nms_results,
+             images_metadata=pre_processing_meta,
+         )
+         results = []
+         for result in rescaled_results:
+             results.append(
+                 Detections(
+                     xyxy=result[:, :4].round().int(),
+                     class_id=result[:, 5].int(),
+                     confidence=result[:, 4],
+                 )
+             )
+         return results
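A similar hedged sketch for the ONNX variant above, showing how execution providers can be passed explicitly. The provider names follow onnxruntime conventions, and the package path is a placeholder:

import numpy as np
import torch

from inference_models.models.yolov8.yolov8_object_detection_onnx import (
    YOLOv8ForObjectDetectionOnnx,
)

model = YOLOv8ForObjectDetectionOnnx.from_pretrained(
    "/path/to/model/package",  # hypothetical package directory
    onnx_execution_providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    device=torch.device("cuda:0"),
)
image = np.zeros((640, 640, 3), dtype=np.uint8)  # stand-in for a real frame
tensor, metadata = model.pre_process(images=image)
detections = model.post_process(
    model_results=model.forward(pre_processed_images=tensor),
    pre_processing_meta=metadata,
    conf_thresh=0.25,
    iou_thresh=0.45,
)

If neither the argument nor the `ONNXRUNTIME_EXECUTION_PROVIDERS` env variable yields a provider list, `from_pretrained` raises `EnvironmentConfigurationError`, as the code above shows.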
inference_models/models/yolov8/yolov8_object_detection_torch_script.py
@@ -0,0 +1,166 @@
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torchvision  # DO NOT REMOVE, THIS IMPORT ENABLES NMS OPERATION
+
+ from inference_models import Detections, ObjectDetectionModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import CorruptedModelPackageError
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     PreProcessingMetadata,
+     ResizeMode,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.post_processing import (
+     post_process_nms_fused_model_output,
+     rescale_detections,
+     run_nms_for_object_detection,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.models.common.torch import generate_batch_chunks
+
+
+ class YOLOv8ForObjectDetectionTorchScript(
+     ObjectDetectionModel[torch.Tensor, PreProcessingMetadata, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "YOLOv8ForObjectDetectionTorchScript":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.torchscript",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "nms":
+             raise CorruptedModelPackageError(
+                 message="Expected NMS to be the post-processing type.",
+                 help_url="https://todo",
+             )
+         if inference_config.forward_pass.static_batch_size is None:
+             raise CorruptedModelPackageError(
+                 message="Expected a static batch size to be registered in the inference configuration.",
+                 help_url="https://todo",
+             )
+         model = torch.jit.load(
+             model_package_content["weights.torchscript"], map_location=device
+         ).eval()
+         return cls(
+             model=model,
+             class_names=class_names,
+             inference_config=inference_config,
+             device=device,
+         )
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+     ):
+         self._model = model
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Union[Tuple[int, int], int]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+             image_size_wh=image_size,
+         )
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+         with torch.inference_mode():
+             if (
+                 pre_processed_images.shape[0]
+                 == self._inference_config.forward_pass.static_batch_size
+             ):
+                 return self._model(pre_processed_images).to(self._device)
+             results = []
+             for input_tensor, padding_size in generate_batch_chunks(
+                 input_batch=pre_processed_images,
+                 chunk_size=self._inference_config.forward_pass.static_batch_size,
+             ):
+                 result_for_chunk = self._model(input_tensor)
+                 if padding_size > 0:
+                     # drop the rows that correspond to padded (dummy) inputs
+                     result_for_chunk = result_for_chunk[:-padding_size]
+                 results.append(result_for_chunk)
+             return torch.cat(results, dim=0).to(self._device)
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         pre_processing_meta: List[PreProcessingMetadata],
+         conf_thresh: float = 0.25,
+         iou_thresh: float = 0.45,
+         max_detections: int = 100,
+         class_agnostic: bool = False,
+         **kwargs,
+     ) -> List[Detections]:
+         if self._inference_config.post_processing.fused:
+             nms_results = post_process_nms_fused_model_output(
+                 output=model_results, conf_thresh=conf_thresh
+             )
+         else:
+             nms_results = run_nms_for_object_detection(
+                 output=model_results,
+                 conf_thresh=conf_thresh,
+                 iou_thresh=iou_thresh,
+                 max_detections=max_detections,
+                 class_agnostic=class_agnostic,
+             )
+         rescaled_results = rescale_detections(
+             detections=nms_results,
+             images_metadata=pre_processing_meta,
+         )
+         results = []
+         for result in rescaled_results:
+             results.append(
+                 Detections(
+                     xyxy=result[:, :4].round().int(),
+                     class_id=result[:, 5].int(),
+                     confidence=result[:, 4],
+                 )
+             )
+         return results
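The `forward` above keeps every TorchScript call at the engine's static batch size by padding the last chunk and then dropping the padded rows. The real `generate_batch_chunks` lives in inference_models/models/common/torch.py and is not shown in this diff; a minimal sketch consistent with how `forward` consumes it could look like this:

from typing import Iterator, Tuple

import torch

def generate_batch_chunks(
    input_batch: torch.Tensor, chunk_size: int
) -> Iterator[Tuple[torch.Tensor, int]]:
    # yields (chunk, padding_size); every chunk has exactly `chunk_size` rows
    for start in range(0, input_batch.shape[0], chunk_size):
        chunk = input_batch[start : start + chunk_size]
        padding_size = chunk_size - chunk.shape[0]
        if padding_size > 0:
            # pad the tail chunk with zero images so the static-batch model
            # accepts it; the caller slices the matching rows off the output
            padding = torch.zeros(
                (padding_size, *chunk.shape[1:]),
                dtype=chunk.dtype,
                device=chunk.device,
            )
            chunk = torch.cat([chunk, padding], dim=0)
        yield chunk, padding_size

Fixing the batch size avoids shape errors in the compiled graph, at the cost of some wasted compute on the padded rows of the final chunk.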