inference-models 0.18.3 (inference_models-0.18.3-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/vit/vit_classification_huggingface.py
@@ -0,0 +1,319 @@
+ import os
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+ from torch import nn
+ from transformers import ViTModel
+
+ from inference_models import (
+     ClassificationModel,
+     ClassificationPrediction,
+     MultiLabelClassificationModel,
+     MultiLabelClassificationPrediction,
+ )
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import CorruptedModelPackageError
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     ResizeMode,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+
+
+ class VITClassifier(nn.Module):
+
+     def __init__(
+         self,
+         backbone: nn.Module,
+         classifier: nn.Module,
+         softmax_fused: bool,
+     ):
+         super().__init__()
+         self._backbone = backbone
+         self._classifier = classifier
+         self._softmax_fused = softmax_fused
+
+     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         outputs = self._backbone(pixel_values=pixel_values)
+         logits = self._classifier(outputs.last_hidden_state[:, 0])
+         if not self._softmax_fused:
+             logits = torch.nn.functional.softmax(logits, dim=-1)
+         return logits
+
+
+ class VITForClassificationHF(ClassificationModel[torch.Tensor, torch.Tensor]):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "VITForClassificationHF":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "classifier_layer_weights.pth",
+                 "inference_config.json",
+                 "vit/config.json",
+                 "vit/model.safetensors",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+                 ResizeMode.FIT_LONGER_EDGE,
+             },
+         )
+         if inference_config.model_initialization is None:
+             raise CorruptedModelPackageError(
+                 message="Expected model initialization parameters not provided in inference config.",
+                 help_url="https://todo",
+             )
+         num_classes = inference_config.model_initialization.get("num_classes")
+         if not isinstance(num_classes, int):
+             raise CorruptedModelPackageError(
+                 message="Expected model initialization parameter `num_classes` not provided or in invalid format.",
+                 help_url="https://todo",
+             )
+         if inference_config.post_processing.type != "softmax":
+             raise CorruptedModelPackageError(
+                 message="Expected Softmax to be the post-processing",
+                 help_url="https://todo",
+             )
+         backbone = ViTModel.from_pretrained(os.path.join(model_name_or_path, "vit")).to(
+             device
+         )
+         classifier = nn.Linear(backbone.config.hidden_size, num_classes).to(device)
+         classifier_state_dict = torch.load(
+             model_package_content["classifier_layer_weights.pth"],
+             weights_only=True,
+             map_location=device,
+         )
+         classifier.load_state_dict(classifier_state_dict)
+         model = (
+             VITClassifier(
+                 backbone=backbone,
+                 classifier=classifier,
+                 softmax_fused=inference_config.post_processing.fused,
+             )
+             .to(device)
+             .eval()
+         )
+         return cls(
+             model=model,
+             inference_config=inference_config,
+             class_names=class_names,
+             device=device,
+         )
+
+     def __init__(
+         self,
+         model: VITClassifier,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+     ):
+         self._model = model
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )[0]
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+         with torch.inference_mode():
+             return self._model(pre_processed_images)
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         **kwargs,
+     ) -> ClassificationPrediction:
+         return ClassificationPrediction(
+             class_id=model_results.argmax(dim=-1),
+             confidence=model_results,
+         )
+
+
+ class VITMultiLabelClassifier(nn.Module):
+
+     def __init__(
+         self,
+         backbone: nn.Module,
+         classifier: nn.Module,
+         sigmoid_fused: bool,
+     ):
+         super().__init__()
+         self._backbone = backbone
+         self._classifier = classifier
+         self._sigmoid_fused = sigmoid_fused
+
+     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         outputs = self._backbone(pixel_values=pixel_values)
+         logits = self._classifier(outputs.last_hidden_state[:, 0])
+         if not self._sigmoid_fused:
+             logits = torch.nn.functional.sigmoid(logits)
+         return logits
+
+
+ class VITForMultiLabelClassificationHF(
+     MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "VITForMultiLabelClassificationHF":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "classifier_layer_weights.pth",
+                 "inference_config.json",
+                 "vit/config.json",
+                 "vit/model.safetensors",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+                 ResizeMode.FIT_LONGER_EDGE,
+             },
+         )
+         if inference_config.model_initialization is None:
+             raise CorruptedModelPackageError(
+                 message="Expected model initialization parameters not provided in inference config.",
+                 help_url="https://todo",
+             )
+         num_classes = inference_config.model_initialization.get("num_classes")
+         if not isinstance(num_classes, int):
+             raise CorruptedModelPackageError(
+                 message="Expected model initialization parameter `num_classes` not provided or in invalid format.",
+                 help_url="https://todo",
+             )
+         if inference_config.post_processing.type != "sigmoid":
+             raise CorruptedModelPackageError(
+                 message="Expected sigmoid to be the post-processing",
+                 help_url="https://todo",
+             )
+         backbone = ViTModel.from_pretrained(os.path.join(model_name_or_path, "vit")).to(
+             device
+         )
+         classifier = nn.Linear(backbone.config.hidden_size, num_classes).to(device)
+         classifier_state_dict = torch.load(
+             model_package_content["classifier_layer_weights.pth"],
+             weights_only=True,
+             map_location=device,
+         )
+         classifier.load_state_dict(classifier_state_dict)
+         model = (
+             VITMultiLabelClassifier(
+                 backbone=backbone,
+                 classifier=classifier,
+                 sigmoid_fused=inference_config.post_processing.fused,
+             )
+             .to(device)
+             .eval()
+         )
+         return cls(
+             model=model,
+             inference_config=inference_config,
+             class_names=class_names,
+             device=device,
+         )
+
+     def __init__(
+         self,
+         model: VITMultiLabelClassifier,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+     ):
+         self._model = model
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )[0]
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+         with torch.inference_mode():
+             return self._model(pre_processed_images)
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         confidence: float = 0.5,
+         **kwargs,
+     ) -> List[MultiLabelClassificationPrediction]:
+         results = []
+         for batch_element_confidence in model_results:
+             predicted_classes = torch.argwhere(
+                 batch_element_confidence >= confidence
+             ).squeeze(dim=-1)
+             results.append(
+                 MultiLabelClassificationPrediction(
+                     class_ids=predicted_classes,
+                     confidence=batch_element_confidence,
+                 )
+             )
+         return results
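
For orientation, here is a minimal usage sketch of the PyTorch-backed classes added in this file. The package directory path and the dummy image array are illustrative assumptions; the expected package contents (class_names.txt, classifier_layer_weights.pth, inference_config.json, vit/) and the pre_process / forward / post_process flow come from the code above.

import numpy as np
import torch

from inference_models.models.vit.vit_classification_huggingface import (
    VITForClassificationHF,
)

# Hypothetical local model package directory laid out as listed in `elements` above.
model = VITForClassificationHF.from_pretrained(
    "/path/to/vit-classification-package",
    device=torch.device("cpu"),
)

image = np.zeros((224, 224, 3), dtype=np.uint8)  # placeholder image array
pre_processed = model.pre_process(image)
logits = model.forward(pre_processed)
prediction = model.post_process(logits)  # ClassificationPrediction with class_id and confidence tensors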
inference_models/models/vit/vit_classification_onnx.py
@@ -0,0 +1,326 @@
+ from threading import Lock
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import (
+     ClassificationModel,
+     ClassificationPrediction,
+     MultiLabelClassificationModel,
+     MultiLabelClassificationPrediction,
+ )
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     CorruptedModelPackageError,
+     EnvironmentConfigurationError,
+     MissingDependencyError,
+ )
+ from inference_models.models.base.types import PreprocessedInputs
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.onnx import (
+     run_session_with_batch_size_limit,
+     set_execution_provider_defaults,
+ )
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     ResizeMode,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.utils.onnx_introspection import (
+     get_selected_onnx_execution_providers,
+ )
+
+ try:
+     import onnxruntime
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import VIT model with ONNX backend - this error means that some additional dependencies "
+         f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+         f"program, make sure the following extras of the package are installed: \n"
+         f"\t* `onnx-cpu` - when you wish to use library with CPU support only\n"
+         f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
+         f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
+         f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
+         f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
+         f"You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+
+ class VITForClassificationOnnx(ClassificationModel[torch.Tensor, torch.Tensor]):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "VITForClassificationOnnx":
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
+                 f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
+                 f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
+                 f"contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.onnx",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "softmax":
+             raise CorruptedModelPackageError(
+                 message="Expected Softmax to be the post-processing",
+                 help_url="https://todo",
+             )
+         session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["weights.onnx"],
+             providers=onnx_execution_providers,
+         )
+         input_shape = session.get_inputs()[0].shape
+         input_batch_size = input_shape[0]
+         if isinstance(input_batch_size, str):
+             input_batch_size = None
+         input_name = session.get_inputs()[0].name
+         return cls(
+             session=session,
+             input_name=input_name,
+             inference_config=inference_config,
+             class_names=class_names,
+             device=device,
+             input_batch_size=input_batch_size,
+         )
+
+     def __init__(
+         self,
+         session: onnxruntime.InferenceSession,
+         input_name: str,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+         input_batch_size: Optional[int],
+     ):
+         self._session = session
+         self._input_name = input_name
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+         self._input_batch_size = input_batch_size
+         self._session_thread_lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )[0]
+
+     def forward(
+         self, pre_processed_images: PreprocessedInputs, **kwargs
+     ) -> torch.Tensor:
+         with self._session_thread_lock:
+             return run_session_with_batch_size_limit(
+                 session=self._session,
+                 inputs={self._input_name: pre_processed_images},
+                 min_batch_size=self._input_batch_size,
+                 max_batch_size=self._input_batch_size,
+             )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         **kwargs,
+     ) -> ClassificationPrediction:
+         if self._inference_config.post_processing.fused:
+             confidence = model_results
+         else:
+             confidence = torch.nn.functional.softmax(model_results, dim=-1)
+         return ClassificationPrediction(
+             class_id=confidence.argmax(dim=-1),
+             confidence=confidence,
+         )
+
+
+ class VITForMultiLabelClassificationOnnx(
+     MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "VITForMultiLabelClassificationOnnx":
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
+                 f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
+                 f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
+                 f"contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.onnx",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "sigmoid":
+             raise CorruptedModelPackageError(
+                 message="Expected sigmoid to be the post-processing",
+                 help_url="https://todo",
+             )
+         session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["weights.onnx"],
+             providers=onnx_execution_providers,
+         )
+         input_shape = session.get_inputs()[0].shape
+         input_batch_size = input_shape[0]
+         if isinstance(input_batch_size, str):
+             input_batch_size = None
+         input_name = session.get_inputs()[0].name
+         return cls(
+             session=session,
+             input_name=input_name,
+             inference_config=inference_config,
+             class_names=class_names,
+             device=device,
+             input_batch_size=input_batch_size,
+         )
+
+     def __init__(
+         self,
+         session: onnxruntime.InferenceSession,
+         input_name: str,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+         input_batch_size: Optional[int],
+     ):
+         self._session = session
+         self._input_name = input_name
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+         self._input_batch_size = input_batch_size
+         self._session_thread_lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )[0]
+
+     def forward(
+         self, pre_processed_images: PreprocessedInputs, **kwargs
+     ) -> torch.Tensor:
+         with self._session_thread_lock:
+             return run_session_with_batch_size_limit(
+                 session=self._session,
+                 inputs={self._input_name: pre_processed_images},
+                 min_batch_size=self._input_batch_size,
+                 max_batch_size=self._input_batch_size,
+             )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         confidence: float = 0.5,
+         **kwargs,
+     ) -> List[MultiLabelClassificationPrediction]:
+         if self._inference_config.post_processing.fused:
+             model_results = model_results
+         else:
+             model_results = torch.nn.functional.sigmoid(model_results)
+         results = []
+         for batch_element_confidence in model_results:
+             predicted_classes = torch.argwhere(
+                 batch_element_confidence >= confidence
+             ).squeeze(dim=-1)
+             results.append(
+                 MultiLabelClassificationPrediction(
+                     class_ids=predicted_classes,
+                     confidence=batch_element_confidence,
+                 )
+             )
+         return results
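
Similarly, a minimal sketch of the ONNX-backed multi-label class added above. The package path, dummy image, and the explicit CPU execution provider are illustrative assumptions; passing the provider list directly avoids relying on the `ONNXRUNTIME_EXECUTION_PROVIDERS` environment variable referenced in the error message, and the expected package contents (class_names.txt, inference_config.json, weights.onnx) come from the `elements` list in this loader.

import numpy as np
import torch

from inference_models.models.vit.vit_classification_onnx import (
    VITForMultiLabelClassificationOnnx,
)

# Hypothetical package directory containing class_names.txt, inference_config.json, weights.onnx.
model = VITForMultiLabelClassificationOnnx.from_pretrained(
    "/path/to/vit-multilabel-package",
    onnx_execution_providers=["CPUExecutionProvider"],
    device=torch.device("cpu"),
)

image = np.zeros((224, 224, 3), dtype=np.uint8)  # placeholder image array
pre_processed = model.pre_process(image)
raw_scores = model.forward(pre_processed)
predictions = model.post_process(raw_scores, confidence=0.5)
for prediction in predictions:
    print(prediction.class_ids, prediction.confidence)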