inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/dinov3/dinov3_classification_onnx.py
@@ -0,0 +1,348 @@
+from threading import Lock
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from inference_models import (
+    ClassificationModel,
+    ClassificationPrediction,
+    MultiLabelClassificationModel,
+    MultiLabelClassificationPrediction,
+)
+from inference_models.configuration import DEFAULT_DEVICE
+from inference_models.entities import ColorFormat
+from inference_models.errors import (
+    CorruptedModelPackageError,
+    EnvironmentConfigurationError,
+    MissingDependencyError,
+)
+from inference_models.models.base.types import PreprocessedInputs
+from inference_models.models.common.model_packages import get_model_package_contents
+from inference_models.models.common.onnx import (
+    run_session_with_batch_size_limit,
+    set_execution_provider_defaults,
+)
+from inference_models.models.common.roboflow.model_packages import (
+    InferenceConfig,
+    ResizeMode,
+    parse_class_names_file,
+    parse_inference_config,
+)
+from inference_models.models.common.roboflow.pre_processing import (
+    pre_process_network_input,
+)
+from inference_models.utils.onnx_introspection import (
+    get_selected_onnx_execution_providers,
+)
+
+try:
+    import onnxruntime
+except ImportError as import_error:
+    raise MissingDependencyError(
+        message=f"Could not import DINOv3 model with ONNX backend - this error means that some additional dependencies "
+        f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+        f"program, make sure the following extras of the package are installed: \n"
+        f"\t* `onnx-cpu` - when you wish to use library with CPU support only\n"
+        f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
+        f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
+        f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
+        f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
+        f"You can also contact Roboflow to get support.",
+        help_url="https://todo",
+    ) from import_error
+
+
+class DinoV3ForClassificationOnnx(ClassificationModel[torch.Tensor, torch.Tensor]):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+        default_onnx_trt_options: bool = True,
+        device: torch.device = DEFAULT_DEVICE,
+        **kwargs,
+    ) -> "DinoV3ForClassificationOnnx":
+        if onnx_execution_providers is None:
+            onnx_execution_providers = get_selected_onnx_execution_providers()  # type: ignore
+        if not onnx_execution_providers:
+            raise EnvironmentConfigurationError(
+                message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
+                f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
+                f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
+                f"contact the platform support.",
+                help_url="https://todo",
+            )
+        onnx_execution_providers = set_execution_provider_defaults(
+            providers=onnx_execution_providers,
+            model_package_path=model_name_or_path,
+            device=device,
+            default_onnx_trt_options=default_onnx_trt_options,
+        )
+
+        required_files = ["class_names.txt", "inference_config.json"]
+        weights_file = "weights.onnx"
+        model_package_content = get_model_package_contents(
+            model_package_dir=model_name_or_path,
+            elements=required_files + [weights_file],
+        )
+
+        class_names = parse_class_names_file(
+            class_names_path=model_package_content["class_names.txt"]
+        )
+        inference_config = parse_inference_config(
+            config_path=model_package_content["inference_config.json"],
+            allowed_resize_modes={
+                ResizeMode.STRETCH_TO,
+                ResizeMode.LETTERBOX,
+                ResizeMode.CENTER_CROP,
+                ResizeMode.LETTERBOX_REFLECT_EDGES,
+            },
+        )
+
+        if (
+            not inference_config.post_processing
+            or inference_config.post_processing.type != "softmax"
+        ):
+            raise CorruptedModelPackageError(
+                message="Expected Softmax to be the post-processing",
+                help_url="https://todo",
+            )
+
+        session = onnxruntime.InferenceSession(
+            path_or_bytes=model_package_content[weights_file],
+            providers=onnx_execution_providers,
+        )
+        input_shape = session.get_inputs()[0].shape
+        input_batch_size = input_shape[0]
+        if isinstance(input_batch_size, str):
+            input_batch_size = None
+        input_name = session.get_inputs()[0].name
+
+        return cls(
+            session=session,
+            input_name=input_name,
+            inference_config=inference_config,
+            class_names=class_names,
+            device=device,
+            input_batch_size=input_batch_size,
+        )
+
+    def __init__(
+        self,
+        session: onnxruntime.InferenceSession,
+        input_name: str,
+        inference_config: InferenceConfig,
+        class_names: List[str],
+        device: torch.device,
+        input_batch_size: Optional[int],
+    ):
+        self._session = session
+        self._input_name = input_name
+        self._inference_config = inference_config
+        self._class_names = class_names
+        self._device = device
+        self._input_batch_size = input_batch_size
+        self._session_thread_lock = Lock()
+
+    @property
+    def class_names(self) -> List[str]:
+        return self._class_names
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return pre_process_network_input(
+            images=images,
+            image_pre_processing=self._inference_config.image_pre_processing,
+            network_input=self._inference_config.network_input,
+            target_device=self._device,
+            input_color_format=input_color_format,
+            image_size_wh=image_size,
+        )[0]
+
+    def forward(
+        self, pre_processed_images: PreprocessedInputs, **kwargs  # type: ignore
+    ) -> torch.Tensor:
+        with self._session_thread_lock:
+            return run_session_with_batch_size_limit(
+                session=self._session,
+                inputs={self._input_name: pre_processed_images},  # type: ignore
+                min_batch_size=self._input_batch_size,
+                max_batch_size=self._input_batch_size,
+            )[0]
+
+    def post_process(
+        self,
+        model_results: torch.Tensor,
+        **kwargs,
+    ) -> ClassificationPrediction:
+        if (
+            self._inference_config.post_processing
+            and self._inference_config.post_processing.fused
+        ):
+            confidence = model_results
+        else:
+            confidence = torch.nn.functional.softmax(model_results, dim=-1)
+        return ClassificationPrediction(
+            class_id=confidence.argmax(dim=-1),
+            confidence=confidence,
+        )
+
+
+class DinoV3ForMultiLabelClassificationOnnx(
+    MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
+):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+        default_onnx_trt_options: bool = True,
+        device: torch.device = DEFAULT_DEVICE,
+        **kwargs,
+    ) -> "DinoV3ForMultiLabelClassificationOnnx":
+        if onnx_execution_providers is None:
+            onnx_execution_providers = get_selected_onnx_execution_providers()  # type: ignore
+        if not onnx_execution_providers:
+            raise EnvironmentConfigurationError(
+                message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
+                f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
+                f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
+                f"contact the platform support.",
+                help_url="https://todo",
+            )
+        onnx_execution_providers = set_execution_provider_defaults(
+            providers=onnx_execution_providers,
+            model_package_path=model_name_or_path,
+            device=device,
+            default_onnx_trt_options=default_onnx_trt_options,
+        )
+
+        required_files = ["class_names.txt", "inference_config.json"]
+        weights_file = "weights.onnx"
+        model_package_content = get_model_package_contents(
+            model_package_dir=model_name_or_path,
+            elements=required_files + [weights_file],
+        )
+
+        class_names = parse_class_names_file(
+            class_names_path=model_package_content["class_names.txt"]
+        )
+        inference_config = parse_inference_config(
+            config_path=model_package_content["inference_config.json"],
+            allowed_resize_modes={
+                ResizeMode.STRETCH_TO,
+                ResizeMode.LETTERBOX,
+                ResizeMode.CENTER_CROP,
+                ResizeMode.LETTERBOX_REFLECT_EDGES,
+            },
+        )
+
+        if (
+            inference_config.post_processing
+            and inference_config.post_processing.type != "sigmoid"
+        ):
+            raise CorruptedModelPackageError(
+                message="Expected Sigmoid to be the post-processing",
+                help_url="https://todo",
+            )
+
+        session = onnxruntime.InferenceSession(
+            path_or_bytes=model_package_content[weights_file],
+            providers=onnx_execution_providers,
+        )
+        input_shape = session.get_inputs()[0].shape
+        input_batch_size = input_shape[0]
+        if isinstance(input_batch_size, str):
+            input_batch_size = None
+        input_name = session.get_inputs()[0].name
+
+        return cls(
+            session=session,
+            input_name=input_name,
+            inference_config=inference_config,
+            class_names=class_names,
+            device=device,
+            input_batch_size=input_batch_size,
+        )
+
+    def __init__(
+        self,
+        session: onnxruntime.InferenceSession,
+        input_name: str,
+        inference_config: InferenceConfig,
+        class_names: List[str],
+        device: torch.device,
+        input_batch_size: Optional[int],
+    ):
+        self._session = session
+        self._input_name = input_name
+        self._inference_config = inference_config
+        self._class_names = class_names
+        self._device = device
+        self._input_batch_size = input_batch_size
+        self._session_thread_lock = Lock()
+
+    @property
+    def class_names(self) -> List[str]:
+        return self._class_names
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return pre_process_network_input(
+            images=images,
+            image_pre_processing=self._inference_config.image_pre_processing,
+            network_input=self._inference_config.network_input,
+            target_device=self._device,
+            input_color_format=input_color_format,
+            image_size_wh=image_size,
+        )[0]
+
+    def forward(
+        self, pre_processed_images: PreprocessedInputs, **kwargs  # type: ignore
+    ) -> torch.Tensor:
+        with self._session_thread_lock:
+            return run_session_with_batch_size_limit(
+                session=self._session,
+                inputs={self._input_name: pre_processed_images},  # type: ignore
+                min_batch_size=self._input_batch_size,
+                max_batch_size=self._input_batch_size,
+            )[0]
+
+    def post_process(
+        self,
+        model_results: torch.Tensor,
+        confidence: float = 0.5,
+        **kwargs,
+    ) -> List[MultiLabelClassificationPrediction]:
+        if (
+            self._inference_config.post_processing
+            and self._inference_config.post_processing.fused
+        ):
+            model_results = model_results
+        else:
+            model_results = torch.nn.functional.sigmoid(model_results)
+        results = []
+        for batch_element_confidence in model_results:
+            predicted_classes = torch.argwhere(
+                batch_element_confidence >= confidence
+            ).squeeze(dim=-1)
+            results.append(
+                MultiLabelClassificationPrediction(
+                    class_ids=predicted_classes,
+                    confidence=batch_element_confidence,
+                )
+            )
+        return results
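
For orientation, a minimal usage sketch of the ONNX classifier defined above. This is hypothetical and not part of the package: the model package directory, the image path, and the use of OpenCV for loading are assumptions; the call sequence and the `ClassificationPrediction` fields follow the class as shown in the diff.

    # Hypothetical usage sketch - assumes a local DINOv3 model package directory
    # containing weights.onnx, class_names.txt and inference_config.json (the files
    # `from_pretrained` requires), and an environment with the `onnx-cpu` extra.
    import cv2

    from inference_models.models.dinov3.dinov3_classification_onnx import (
        DinoV3ForClassificationOnnx,
    )

    model = DinoV3ForClassificationOnnx.from_pretrained(
        "/path/to/model-package",
        onnx_execution_providers=["CPUExecutionProvider"],
    )
    image = cv2.imread("example.jpg")  # np.ndarray input, as accepted by pre_process
    pre_processed = model.pre_process(image)
    logits = model.forward(pre_processed)
    prediction = model.post_process(logits)
    # class_id holds one index per batch element; confidence the full distribution
    class_idx = int(prediction.class_id[0])
    print(model.class_names[class_idx], float(prediction.confidence[0, class_idx]))

Since `forward` serializes session calls through a `Lock`, a single instance can be shared across threads.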
inference_models/models/dinov3/dinov3_classification_torch.py
@@ -0,0 +1,323 @@
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import timm
+import torch
+from torch import nn
+
+from inference_models import (
+    ClassificationModel,
+    ClassificationPrediction,
+    MultiLabelClassificationModel,
+    MultiLabelClassificationPrediction,
+)
+from inference_models.configuration import DEFAULT_DEVICE
+from inference_models.entities import ColorFormat
+from inference_models.errors import CorruptedModelPackageError
+from inference_models.models.common.model_packages import get_model_package_contents
+from inference_models.models.common.roboflow.model_packages import (
+    InferenceConfig,
+    ResizeMode,
+    parse_class_names_file,
+    parse_inference_config,
+)
+from inference_models.models.common.roboflow.pre_processing import (
+    pre_process_network_input,
+)
+
+
+class DinoV3Model(nn.Module):
+    """DINOv3 model for classification using timm's EVA ViT backbone."""
+
+    def __init__(
+        self, num_classes: int, model_name: str = "vit_small_patch16_dinov3.lvd1689m"
+    ):
+        """
+        Args:
+            num_classes: Number of classes to classify
+            model_name: Name of the backbone model from timm
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.model_name = model_name
+
+        # DinoV3 is implemented as a parameterization of EVA ViT in timm
+        self.backbone: timm.models.Eva = timm.create_model(
+            self.model_name, pretrained=False
+        )
+        self.backbone = self.backbone.eval()
+        self.linear_layer = nn.Linear(self.backbone.embed_dim, num_classes)
+
+    def forward_embedding(self, x: torch.Tensor) -> torch.Tensor:
+        """Extract features using the CLS token (position 0)."""
+        return self.backbone.forward_features(x)[:, 0]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear_layer(self.forward_embedding(x))
+
+
+class DinoV3ForClassificationTorch(ClassificationModel[torch.Tensor, torch.Tensor]):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        device: torch.device = DEFAULT_DEVICE,
+        **kwargs,
+    ) -> "DinoV3ForClassificationTorch":
+        model_package_content = get_model_package_contents(
+            model_package_dir=model_name_or_path,
+            elements=[
+                "class_names.txt",
+                "inference_config.json",
+                "weights.pth",
+            ],
+        )
+        class_names = parse_class_names_file(
+            class_names_path=model_package_content["class_names.txt"]
+        )
+        inference_config = parse_inference_config(
+            config_path=model_package_content["inference_config.json"],
+            allowed_resize_modes={
+                ResizeMode.STRETCH_TO,
+                ResizeMode.LETTERBOX,
+                ResizeMode.CENTER_CROP,
+                ResizeMode.LETTERBOX_REFLECT_EDGES,
+            },
+        )
+
+        if inference_config.model_initialization is None:
+            raise CorruptedModelPackageError(
+                message="Expected model initialization parameters not provided in inference config.",
+                help_url="https://todo",
+            )
+        num_classes = inference_config.model_initialization.get("num_classes")
+        model_name = inference_config.model_initialization.get("model_name")
+        if not isinstance(num_classes, int):
+            raise CorruptedModelPackageError(
+                message="Expected model initialization parameter `num_classes` not provided or in invalid format.",
+                help_url="https://todo",
+            )
+        if not isinstance(model_name, str):
+            raise CorruptedModelPackageError(
+                message="Expected model initialization parameter `model_name` not provided or in invalid format.",
+                help_url="https://todo",
+            )
+
+        if (
+            not inference_config.post_processing
+            or inference_config.post_processing.type != "softmax"
+        ):
+            raise CorruptedModelPackageError(
+                message="Expected Softmax to be the post-processing",
+                help_url="https://todo",
+            )
+
+        # Create model and load weights
+        model = DinoV3Model(num_classes=num_classes, model_name=model_name)
+        state_dict = torch.load(
+            model_package_content["weights.pth"],
+            map_location=device,
+            weights_only=True,
+        )
+        model.load_state_dict(state_dict)
+        model = model.to(device).eval()
+
+        return cls(
+            model=model,
+            inference_config=inference_config,
+            class_names=class_names,
+            device=device,
+        )
+
+    def __init__(
+        self,
+        model: DinoV3Model,
+        inference_config: InferenceConfig,
+        class_names: List[str],
+        device: torch.device,
+    ):
+        self._model = model
+        self._inference_config = inference_config
+        self._class_names = class_names
+        self._device = device
+
+    @property
+    def class_names(self) -> List[str]:
+        return self._class_names
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return pre_process_network_input(
+            images=images,
+            image_pre_processing=self._inference_config.image_pre_processing,
+            network_input=self._inference_config.network_input,
+            target_device=self._device,
+            input_color_format=input_color_format,
+            image_size_wh=image_size,
+        )[0]
+
+    def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+        with torch.inference_mode():
+            return self._model(pre_processed_images)
+
+    def post_process(
+        self,
+        model_results: torch.Tensor,
+        **kwargs,
+    ) -> ClassificationPrediction:
+        if (
+            self._inference_config.post_processing
+            and self._inference_config.post_processing.fused
+        ):
+            confidence = model_results
+        else:
+            confidence = torch.nn.functional.softmax(model_results, dim=-1)
+        return ClassificationPrediction(
+            class_id=confidence.argmax(dim=-1),
+            confidence=confidence,
+        )
+
+
+class DinoV3ForMultiLabelClassificationTorch(
+    MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
+):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        device: torch.device = DEFAULT_DEVICE,
+        **kwargs,
+    ) -> "DinoV3ForMultiLabelClassificationTorch":
+        model_package_content = get_model_package_contents(
+            model_package_dir=model_name_or_path,
+            elements=[
+                "class_names.txt",
+                "inference_config.json",
+                "weights.pth",
+            ],
+        )
+        class_names = parse_class_names_file(
+            class_names_path=model_package_content["class_names.txt"]
+        )
+        inference_config = parse_inference_config(
+            config_path=model_package_content["inference_config.json"],
+            allowed_resize_modes={
+                ResizeMode.STRETCH_TO,
+                ResizeMode.LETTERBOX,
+                ResizeMode.CENTER_CROP,
+                ResizeMode.LETTERBOX_REFLECT_EDGES,
+            },
+        )
+
+        if inference_config.model_initialization is None:
+            raise CorruptedModelPackageError(
+                message="Expected model initialization parameters not provided in inference config.",
+                help_url="https://todo",
+            )
+        num_classes = inference_config.model_initialization.get("num_classes")
+        model_name = inference_config.model_initialization.get("model_name")
+        if not isinstance(num_classes, int):
+            raise CorruptedModelPackageError(
+                message="Expected model initialization parameter `num_classes` not provided or in invalid format.",
+                help_url="https://todo",
+            )
+        if not isinstance(model_name, str):
+            raise CorruptedModelPackageError(
+                message="Expected model initialization parameter `model_name` not provided or in invalid format.",
+                help_url="https://todo",
+            )
+
+        if (
+            inference_config.post_processing
+            and inference_config.post_processing.type != "sigmoid"
+        ):
+            raise CorruptedModelPackageError(
+                message="Expected Sigmoid to be the post-processing",
+                help_url="https://todo",
+            )
+
+        # Create model and load weights
+        model = DinoV3Model(num_classes=num_classes, model_name=model_name)
+        state_dict = torch.load(
+            model_package_content["weights.pth"],
+            map_location=device,
+            weights_only=True,
+        )
+        model.load_state_dict(state_dict)
+        model = model.to(device).eval()
+
+        return cls(
+            model=model,
+            inference_config=inference_config,
+            class_names=class_names,
+            device=device,
+        )
+
+    def __init__(
+        self,
+        model: DinoV3Model,
+        inference_config: InferenceConfig,
+        class_names: List[str],
+        device: torch.device,
+    ):
+        self._model = model
+        self._inference_config = inference_config
+        self._class_names = class_names
+        self._device = device
+
+    @property
+    def class_names(self) -> List[str]:
+        return self._class_names
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return pre_process_network_input(
+            images=images,
+            image_pre_processing=self._inference_config.image_pre_processing,
+            network_input=self._inference_config.network_input,
+            target_device=self._device,
+            input_color_format=input_color_format,
+            image_size_wh=image_size,
+        )[0]
+
+    def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+        with torch.inference_mode():
+            return self._model(pre_processed_images)
+
+    def post_process(
+        self,
+        model_results: torch.Tensor,
+        confidence: float = 0.5,
+        **kwargs,
+    ) -> List[MultiLabelClassificationPrediction]:
+        if (
+            self._inference_config.post_processing
+            and self._inference_config.post_processing.fused
+        ):
+            model_results = model_results
+        else:
+            model_results = torch.nn.functional.sigmoid(model_results)
+        results = []
+        for batch_element_confidence in model_results:
+            predicted_classes = torch.argwhere(
+                batch_element_confidence >= confidence
+            ).squeeze(dim=-1)
+            results.append(
+                MultiLabelClassificationPrediction(
+                    class_ids=predicted_classes,
+                    confidence=batch_element_confidence,
+                )
+            )
+        return results
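
And a matching sketch for the Torch backend, here using the multi-label variant. Again the package directory, image path, and OpenCV usage are assumptions; `from_pretrained`, `post_process(confidence=...)`, and the `class_ids` field come from the diff above.

    # Hypothetical usage sketch - assumes a model package directory containing
    # weights.pth, class_names.txt and inference_config.json, with torch and
    # timm installed.
    import cv2
    import torch

    from inference_models.models.dinov3.dinov3_classification_torch import (
        DinoV3ForMultiLabelClassificationTorch,
    )

    model = DinoV3ForMultiLabelClassificationTorch.from_pretrained(
        "/path/to/model-package",
        device=torch.device("cpu"),
    )
    image = cv2.imread("example.jpg")
    logits = model.forward(model.pre_process(image))
    predictions = model.post_process(logits, confidence=0.5)
    # one prediction per batch element; class_ids holds indices above the threshold
    for class_id in predictions[0].class_ids.tolist():
        print(model.class_names[class_id])

A design note visible in the code: the Torch classes validate `num_classes` and `model_name` from `inference_config.json` before building the timm backbone, and run the network under `torch.inference_mode()`.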