inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195) hide show
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,305 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import timm
5
+ import torch
6
+ from torch import nn
7
+
8
+ from inference_models import (
9
+ ClassificationModel,
10
+ ClassificationPrediction,
11
+ MultiLabelClassificationModel,
12
+ MultiLabelClassificationPrediction,
13
+ )
14
+ from inference_models.configuration import DEFAULT_DEVICE
15
+ from inference_models.entities import ColorFormat
16
+ from inference_models.errors import CorruptedModelPackageError
17
+ from inference_models.models.common.model_packages import get_model_package_contents
18
+ from inference_models.models.common.roboflow.model_packages import (
19
+ InferenceConfig,
20
+ ResizeMode,
21
+ parse_class_names_file,
22
+ parse_inference_config,
23
+ )
24
+ from inference_models.models.common.roboflow.pre_processing import (
25
+ pre_process_network_input,
26
+ )
27
+
28
+
29
class ResNetClassifier(nn.Module):
    """Backbone wrapper that yields softmax-normalised class scores.

    When the backbone already applies softmax (``softmax_fused=True``) its
    output is returned untouched; otherwise softmax is applied over the last
    dimension.
    """

    def __init__(self, backbone: nn.Module, softmax_fused: bool):
        """Store the backbone and the flag telling whether softmax is fused.

        Args:
            backbone: Module producing the raw classification output.
            softmax_fused: ``True`` when ``backbone`` output is already
                softmax-normalised.
        """
        super().__init__()
        # Attribute names are part of the module's state-dict key prefix.
        self._backbone = backbone
        self._softmax_fused = softmax_fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and normalise its output if needed.

        Args:
            x: Input tensor passed straight to the backbone.

        Returns:
            Backbone output, with softmax over the last dimension applied
            unless it is already fused into the backbone.
        """
        backbone_output = self._backbone(x)
        if self._softmax_fused:
            return backbone_output
        return backbone_output.softmax(dim=-1)
41
+
42
+
43
class ResNetForClassificationTorch(ClassificationModel[torch.Tensor, torch.Tensor]):
    """Single-label ResNet classifier running on the PyTorch backend."""

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> "ResNetForClassificationTorch":
        """Create the model from a model package directory.

        The package is expected to provide ``class_names.txt``,
        ``inference_config.json`` and ``weights.pth``.

        Args:
            model_name_or_path: Passed to ``get_model_package_contents`` as the
                model package directory.
            device: Device the network and its weights are placed on.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Model instance wrapping the network in eval mode.

        Raises:
            CorruptedModelPackageError: When the inference config has no model
                initialization section, when ``num_classes`` is not an ``int``,
                when ``model_name`` is not a ``str``, or when the declared
                post-processing type is not ``"softmax"``.
        """
        package_files = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=["class_names.txt", "inference_config.json", "weights.pth"],
        )
        class_names = parse_class_names_file(
            class_names_path=package_files["class_names.txt"]
        )
        supported_resize_modes = {
            ResizeMode.STRETCH_TO,
            ResizeMode.LETTERBOX,
            ResizeMode.CENTER_CROP,
            ResizeMode.LETTERBOX_REFLECT_EDGES,
        }
        inference_config = parse_inference_config(
            config_path=package_files["inference_config.json"],
            allowed_resize_modes=supported_resize_modes,
        )
        init_params = inference_config.model_initialization
        if init_params is None:
            raise CorruptedModelPackageError(
                message="Expected model initialization parameters not provided in inference config.",
                help_url="https://todo",
            )
        num_classes = init_params.get("num_classes")
        model_name = init_params.get("model_name")
        if not isinstance(num_classes, int):
            raise CorruptedModelPackageError(
                message="Expected model initialization parameter `num_classes` not provided or in invalid format.",
                help_url="https://todo",
            )
        if not isinstance(model_name, str):
            raise CorruptedModelPackageError(
                message="Expected model initialization parameter `model_name` not provided or in invalid format.",
                help_url="https://todo",
            )
        if inference_config.post_processing.type != "softmax":
            raise CorruptedModelPackageError(
                message="Expected softmax to be the post-processing",
                help_url="https://todo",
            )
        # The architecture is created without pretrained weights; the actual
        # parameters come from the package's ``weights.pth``.
        backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=num_classes,
        )
        backbone = backbone.to(device)
        weights = torch.load(
            package_files["weights.pth"],
            weights_only=True,
            map_location=device,
        )
        backbone.load_state_dict(weights)
        classifier = ResNetClassifier(
            backbone=backbone,
            softmax_fused=inference_config.post_processing.fused,
        )
        classifier = classifier.to(device)
        return cls(
            model=classifier.eval(),
            inference_config=inference_config,
            class_names=class_names,
            device=device,
        )

    def __init__(
        self,
        model: ResNetClassifier,
        inference_config: InferenceConfig,
        class_names: List[str],
        device: torch.device,
    ):
        """Keep references to the network, its config, labels and device.

        Args:
            model: Network used by ``forward``.
            inference_config: Parsed inference configuration; its
                ``image_pre_processing`` and ``network_input`` sections are
                used by ``pre_process``.
            class_names: Class labels exposed through ``class_names``.
            device: Target device handed to the pre-processing routine.
        """
        self._model = model
        self._inference_config = inference_config
        self._class_names = class_names
        self._device = device

    @property
    def class_names(self) -> List[str]:
        """Class labels supplied at construction time."""
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Turn input images into the network input tensor.

        Args:
            images: Image or list of images, as tensors or numpy arrays.
            input_color_format: Optional color format of ``images``.
            image_size: Optional size override, forwarded as ``image_size_wh``
                (presumably ``(width, height)`` - confirm against
                ``pre_process_network_input``).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            First element of the value returned by
            ``pre_process_network_input``.
        """
        config = self._inference_config
        pre_processing_result = pre_process_network_input(
            images=images,
            image_pre_processing=config.image_pre_processing,
            network_input=config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )
        # Only the first element is used; the remaining ones are discarded.
        return pre_processing_result[0]

    def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run the network under ``torch.inference_mode``.

        Args:
            pre_processed_images: Network input tensor.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Output of the wrapped network.
        """
        with torch.inference_mode():
            network_output = self._model(pre_processed_images)
        return network_output

    def post_process(
        self,
        model_results: torch.Tensor,
        **kwargs,
    ) -> ClassificationPrediction:
        """Package network output as a classification prediction.

        Args:
            model_results: Output of ``forward``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Prediction whose ``class_id`` is the argmax over the last
            dimension and whose ``confidence`` is ``model_results`` itself.
        """
        top_class = model_results.argmax(dim=-1)
        return ClassificationPrediction(
            class_id=top_class,
            confidence=model_results,
        )
161
+
162
+
163
class ResNetMultiLabelClassifier(nn.Module):
    """Backbone wrapper that yields per-class sigmoid scores.

    When the backbone already applies sigmoid (``sigmoid_fused=True``) its
    output is returned untouched; otherwise an element-wise sigmoid is applied.
    """

    def __init__(self, backbone: nn.Module, sigmoid_fused: bool):
        """Store the backbone and the flag telling whether sigmoid is fused.

        Args:
            backbone: Module producing the raw classification output.
            sigmoid_fused: ``True`` when ``backbone`` output already went
                through sigmoid.
        """
        super().__init__()
        # Attribute names are part of the module's state-dict key prefix.
        self._backbone = backbone
        self._sigmoid_fused = sigmoid_fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and squash its output if needed.

        Args:
            x: Input tensor passed straight to the backbone.

        Returns:
            Backbone output, with element-wise sigmoid applied unless it is
            already fused into the backbone.
        """
        backbone_output = self._backbone(x)
        if self._sigmoid_fused:
            return backbone_output
        return backbone_output.sigmoid()
175
+
176
+
177
class ResNetForMultiLabelClassificationTorch(
    MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
):
    """Multi-label ResNet classifier running on the PyTorch backend."""

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> "ResNetForMultiLabelClassificationTorch":
        """Create the model from a model package directory.

        The package is expected to provide ``class_names.txt``,
        ``inference_config.json`` and ``weights.pth``.

        Args:
            model_name_or_path: Passed to ``get_model_package_contents`` as the
                model package directory.
            device: Device the network and its weights are placed on.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Model instance wrapping the network in eval mode.

        Raises:
            CorruptedModelPackageError: When the inference config has no model
                initialization section, when ``num_classes`` is not an ``int``,
                when ``model_name`` is not a ``str``, or when the declared
                post-processing type is not ``"sigmoid"``.
        """
        package_files = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=["class_names.txt", "inference_config.json", "weights.pth"],
        )
        class_names = parse_class_names_file(
            class_names_path=package_files["class_names.txt"]
        )
        supported_resize_modes = {
            ResizeMode.STRETCH_TO,
            ResizeMode.LETTERBOX,
            ResizeMode.CENTER_CROP,
            ResizeMode.LETTERBOX_REFLECT_EDGES,
        }
        inference_config = parse_inference_config(
            config_path=package_files["inference_config.json"],
            allowed_resize_modes=supported_resize_modes,
        )
        init_params = inference_config.model_initialization
        if init_params is None:
            raise CorruptedModelPackageError(
                message="Expected model initialization parameters not provided in inference config.",
                help_url="https://todo",
            )
        num_classes = init_params.get("num_classes")
        model_name = init_params.get("model_name")
        if not isinstance(num_classes, int):
            raise CorruptedModelPackageError(
                message="Expected model initialization parameter `num_classes` not provided or in invalid format.",
                help_url="https://todo",
            )
        if not isinstance(model_name, str):
            raise CorruptedModelPackageError(
                message="Expected model initialization parameter `model_name` not provided or in invalid format.",
                help_url="https://todo",
            )
        if inference_config.post_processing.type != "sigmoid":
            raise CorruptedModelPackageError(
                message="Expected sigmoid to be the post-processing",
                help_url="https://todo",
            )
        # The architecture is created without pretrained weights; the actual
        # parameters come from the package's ``weights.pth``.
        backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=num_classes,
        )
        backbone = backbone.to(device)
        weights = torch.load(
            package_files["weights.pth"],
            weights_only=True,
            map_location=device,
        )
        backbone.load_state_dict(weights)
        classifier = ResNetMultiLabelClassifier(
            backbone=backbone,
            sigmoid_fused=inference_config.post_processing.fused,
        )
        classifier = classifier.to(device)
        return cls(
            model=classifier.eval(),
            inference_config=inference_config,
            class_names=class_names,
            device=device,
        )

    def __init__(
        self,
        model: ResNetMultiLabelClassifier,
        inference_config: InferenceConfig,
        class_names: List[str],
        device: torch.device,
    ):
        """Keep references to the network, its config, labels and device.

        Args:
            model: Network used by ``forward``.
            inference_config: Parsed inference configuration; its
                ``image_pre_processing`` and ``network_input`` sections are
                used by ``pre_process``.
            class_names: Class labels exposed through ``class_names``.
            device: Target device handed to the pre-processing routine.
        """
        self._model = model
        self._inference_config = inference_config
        self._class_names = class_names
        self._device = device

    @property
    def class_names(self) -> List[str]:
        """Class labels supplied at construction time."""
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Turn input images into the network input tensor.

        Args:
            images: Image or list of images, as tensors or numpy arrays.
            input_color_format: Optional color format of ``images``.
            image_size: Optional size override, forwarded as ``image_size_wh``
                (presumably ``(width, height)`` - confirm against
                ``pre_process_network_input``).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            First element of the value returned by
            ``pre_process_network_input``.
        """
        config = self._inference_config
        pre_processing_result = pre_process_network_input(
            images=images,
            image_pre_processing=config.image_pre_processing,
            network_input=config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )
        # Only the first element is used; the remaining ones are discarded.
        return pre_processing_result[0]

    def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run the network under ``torch.inference_mode``.

        Args:
            pre_processed_images: Network input tensor.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Output of the wrapped network.
        """
        with torch.inference_mode():
            network_output = self._model(pre_processed_images)
        return network_output

    def post_process(
        self,
        model_results: torch.Tensor,
        confidence: float = 0.5,
        **kwargs,
    ) -> List[MultiLabelClassificationPrediction]:
        """Convert network output into one prediction per batch element.

        Args:
            model_results: Output of ``forward``; iterated along its first
                dimension.
            confidence: Threshold; classes scoring at least this value are
                reported as predicted.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            One prediction per element of ``model_results``. ``class_ids``
            holds the indices whose score is ``>= confidence`` and
            ``confidence`` holds the full score vector of that element.
        """
        return [
            MultiLabelClassificationPrediction(
                # ``argwhere`` gives one index row per hit; the trailing
                # dimension is squeezed to obtain a flat vector of class ids.
                class_ids=torch.argwhere(element_scores >= confidence).squeeze(
                    dim=-1
                ),
                confidence=element_scores,
            )
            for element_scores in model_results
        ]
@@ -0,0 +1,369 @@
1
+ from threading import Lock
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from inference_models import (
8
+ ClassificationModel,
9
+ ClassificationPrediction,
10
+ MultiLabelClassificationModel,
11
+ MultiLabelClassificationPrediction,
12
+ )
13
+ from inference_models.configuration import DEFAULT_DEVICE
14
+ from inference_models.entities import ColorFormat
15
+ from inference_models.errors import (
16
+ CorruptedModelPackageError,
17
+ MissingDependencyError,
18
+ ModelRuntimeError,
19
+ )
20
+ from inference_models.models.base.types import PreprocessedInputs
21
+ from inference_models.models.common.cuda import (
22
+ use_cuda_context,
23
+ use_primary_cuda_context,
24
+ )
25
+ from inference_models.models.common.model_packages import get_model_package_contents
26
+ from inference_models.models.common.roboflow.model_packages import (
27
+ InferenceConfig,
28
+ ResizeMode,
29
+ TRTConfig,
30
+ parse_class_names_file,
31
+ parse_inference_config,
32
+ parse_trt_config,
33
+ )
34
+ from inference_models.models.common.roboflow.pre_processing import (
35
+ pre_process_network_input,
36
+ )
37
+ from inference_models.models.common.trt import (
38
+ get_engine_inputs_and_outputs,
39
+ infer_from_trt_engine,
40
+ load_model,
41
+ )
42
+
43
# TensorRT is an optional dependency: fail at import time with an actionable
# error instead of a bare ImportError when it is missing.
try:
    import tensorrt as trt
except ImportError as import_error:
    raise MissingDependencyError(
        message="Could not import ResNet model with TRT backend - this error means that some additional dependencies "
        "are not installed in the environment. If you run the `inference-models` library directly in your Python "
        "program, make sure the following extras of the package are installed: `trt10` - installation can only "
        "succeed for Linux and Windows machines with Cuda 12 installed. Jetson devices, should have TRT 10.x "
        "installed for all builds with Jetpack 6. "
        "If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
        "You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error

# pycuda is used by this module to initialise CUDA and manage the CUDA context
# in which the TRT engine runs, so it is required as well.
try:
    import pycuda.driver as cuda
except ImportError as import_error:
    raise MissingDependencyError(
        message="Could not import ResNet model with TRT backend - the `pycuda` package, which is required to manage "
        "CUDA contexts for TRT engines, is not installed in the environment. If you run the `inference-models` "
        "library directly in your Python program, make sure `pycuda` is installed together with the TensorRT "
        "dependencies. "
        "If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
        "You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error
63
+
64
+
65
class ResNetForClassificationTRT(ClassificationModel[torch.Tensor, torch.Tensor]):
    """Single-label ResNet classifier running on a TensorRT engine."""

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        engine_host_code_allowed: bool = False,
        **kwargs,
    ) -> "ResNetForClassificationTRT":
        """Create the model from a model package directory.

        The package is expected to provide ``class_names.txt``,
        ``inference_config.json``, ``trt_config.json`` and ``engine.plan``.

        Args:
            model_name_or_path: Passed to ``get_model_package_contents`` as the
                model package directory.
            device: Device to run on; its type must be ``"cuda"``.
            engine_host_code_allowed: Forwarded to ``load_model``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Model instance holding the engine, its execution context and the
            CUDA context they were created in.

        Raises:
            ModelRuntimeError: When ``device`` is not a CUDA device.
            CorruptedModelPackageError: When the declared post-processing type
                is not ``"softmax"`` or the engine does not have exactly one
                input and one output.
        """
        if device.type != "cuda":
            raise ModelRuntimeError(
                message=f"TRT engine only runs on CUDA device - {device} device detected.",
                help_url="https://todo",
            )
        package_files = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "trt_config.json",
                "engine.plan",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=package_files["class_names.txt"]
        )
        supported_resize_modes = {
            ResizeMode.STRETCH_TO,
            ResizeMode.LETTERBOX,
            ResizeMode.CENTER_CROP,
            ResizeMode.LETTERBOX_REFLECT_EDGES,
        }
        inference_config = parse_inference_config(
            config_path=package_files["inference_config.json"],
            allowed_resize_modes=supported_resize_modes,
        )
        if inference_config.post_processing.type != "softmax":
            raise CorruptedModelPackageError(
                message="Expected Softmax to be the post-processing",
                help_url="https://todo",
            )
        trt_config = parse_trt_config(config_path=package_files["trt_config.json"])
        cuda.init()
        # ``device.index`` may be None (e.g. plain "cuda"); device 0 is used then.
        device_ordinal = device.index or 0
        cuda_device = cuda.Device(device_ordinal)
        # Engine and execution context are both created while the device's
        # primary CUDA context is active.
        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
            engine = load_model(
                model_path=package_files["engine.plan"],
                engine_host_code_allowed=engine_host_code_allowed,
            )
            execution_context = engine.create_execution_context()
        inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
        if len(inputs) != 1:
            raise CorruptedModelPackageError(
                message=f"Implementation assume single model input, found: {len(inputs)}.",
                help_url="https://todo",
            )
        if len(outputs) != 1:
            raise CorruptedModelPackageError(
                message=f"Implementation assume single model output, found: {len(outputs)}.",
                help_url="https://todo",
            )
        (input_name,) = inputs
        (output_name,) = outputs
        return cls(
            engine=engine,
            input_name=input_name,
            output_name=output_name,
            class_names=class_names,
            inference_config=inference_config,
            trt_config=trt_config,
            device=device,
            cuda_context=cuda_context,
            execution_context=execution_context,
        )

    def __init__(
        self,
        engine: trt.ICudaEngine,
        input_name: str,
        output_name: str,
        class_names: List[str],
        inference_config: InferenceConfig,
        trt_config: TRTConfig,
        device: torch.device,
        cuda_context: cuda.Context,
        execution_context: trt.IExecutionContext,
    ):
        """Keep references to the engine, its contexts and configuration.

        Args:
            engine: Loaded TensorRT engine.
            input_name: Name of the single engine input.
            output_name: Name of the single engine output.
            class_names: Class labels exposed through ``class_names``.
            inference_config: Parsed inference configuration.
            trt_config: Parsed TRT configuration handed to the inference
                helper.
            device: CUDA device used for pre-processing and inference.
            cuda_context: CUDA context activated around each inference call.
            execution_context: TensorRT execution context of ``engine``.
        """
        self._engine = engine
        self._input_name = input_name
        # The inference helper takes a list of output names.
        self._output_names = [output_name]
        self._class_names = class_names
        self._inference_config = inference_config
        self._trt_config = trt_config
        self._device = device
        self._cuda_context = cuda_context
        self._execution_context = execution_context
        # Held for the whole duration of each ``forward`` call.
        self._lock = Lock()

    @property
    def class_names(self) -> List[str]:
        """Class labels supplied at construction time."""
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Turn input images into the engine input tensor.

        Args:
            images: Image or list of images, as tensors or numpy arrays.
            input_color_format: Optional color format of ``images``.
            image_size: Optional size override, forwarded as ``image_size_wh``
                (presumably ``(width, height)`` - confirm against
                ``pre_process_network_input``).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            First element of the value returned by
            ``pre_process_network_input``.
        """
        config = self._inference_config
        pre_processing_result = pre_process_network_input(
            images=images,
            image_pre_processing=config.image_pre_processing,
            network_input=config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )
        # Only the first element is used; the remaining ones are discarded.
        return pre_processing_result[0]

    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> torch.Tensor:
        """Run the TensorRT engine on pre-processed input.

        The instance lock is taken first and the stored CUDA context is then
        activated, so calls on the same instance run one at a time.

        Args:
            pre_processed_images: Input produced by ``pre_process``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            First element of the results returned by ``infer_from_trt_engine``
            (the engine's single output).
        """
        with self._lock, use_cuda_context(context=self._cuda_context):
            engine_outputs = infer_from_trt_engine(
                pre_processed_images=pre_processed_images,
                trt_config=self._trt_config,
                engine=self._engine,
                context=self._execution_context,
                device=self._device,
                input_name=self._input_name,
                outputs=self._output_names,
            )
            return engine_outputs[0]

    def post_process(
        self,
        model_results: torch.Tensor,
        **kwargs,
    ) -> ClassificationPrediction:
        """Package engine output as a classification prediction.

        Softmax over the last dimension is applied here unless the inference
        config marks it as fused into the model.

        Args:
            model_results: Output of ``forward``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Prediction whose ``class_id`` is the argmax over the last
            dimension of the scores and whose ``confidence`` is the scores.
        """
        softmax_fused = self._inference_config.post_processing.fused
        scores = (
            model_results
            if softmax_fused
            else torch.nn.functional.softmax(model_results, dim=-1)
        )
        return ClassificationPrediction(
            class_id=scores.argmax(dim=-1),
            confidence=scores,
        )
211
+
212
+
213
class ResNetForMultiLabelClassificationTRT(
    MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
):
    """Multi-label ResNet classifier running on a TensorRT engine."""

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        engine_host_code_allowed: bool = False,
        **kwargs,
    ) -> "ResNetForMultiLabelClassificationTRT":
        """Create the model from a model package directory.

        The package is expected to provide ``class_names.txt``,
        ``inference_config.json``, ``trt_config.json`` and ``engine.plan``.

        Args:
            model_name_or_path: Passed to ``get_model_package_contents`` as the
                model package directory.
            device: Device to run on; its type must be ``"cuda"``.
            engine_host_code_allowed: Forwarded to ``load_model``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Model instance holding the engine, its execution context and the
            CUDA context they were created in.

        Raises:
            ModelRuntimeError: When ``device`` is not a CUDA device.
            CorruptedModelPackageError: When the declared post-processing type
                is not ``"sigmoid"`` or the engine does not have exactly one
                input and one output.
        """
        if device.type != "cuda":
            raise ModelRuntimeError(
                message=f"TRT engine only runs on CUDA device - {device} device detected.",
                help_url="https://todo",
            )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "trt_config.json",
                "engine.plan",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        if inference_config.post_processing.type != "sigmoid":
            raise CorruptedModelPackageError(
                message="Expected sigmoid to be the post-processing",
                help_url="https://todo",
            )
        trt_config = parse_trt_config(
            config_path=model_package_content["trt_config.json"]
        )
        cuda.init()
        # ``device.index`` may be None (e.g. plain "cuda"); device 0 is used then.
        cuda_device = cuda.Device(device.index or 0)
        # Engine and execution context are both created while the device's
        # primary CUDA context is active.
        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
            engine = load_model(
                model_path=model_package_content["engine.plan"],
                engine_host_code_allowed=engine_host_code_allowed,
            )
            execution_context = engine.create_execution_context()
        inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
        if len(inputs) != 1:
            raise CorruptedModelPackageError(
                message=f"Implementation assume single model input, found: {len(inputs)}.",
                help_url="https://todo",
            )
        if len(outputs) != 1:
            raise CorruptedModelPackageError(
                message=f"Implementation assume single model output, found: {len(outputs)}.",
                help_url="https://todo",
            )
        return cls(
            engine=engine,
            input_name=inputs[0],
            output_name=outputs[0],
            class_names=class_names,
            inference_config=inference_config,
            trt_config=trt_config,
            device=device,
            cuda_context=cuda_context,
            execution_context=execution_context,
        )

    def __init__(
        self,
        engine: trt.ICudaEngine,
        input_name: str,
        output_name: str,
        class_names: List[str],
        inference_config: InferenceConfig,
        trt_config: TRTConfig,
        device: torch.device,
        cuda_context: cuda.Context,
        execution_context: trt.IExecutionContext,
    ):
        """Keep references to the engine, its contexts and configuration.

        Args:
            engine: Loaded TensorRT engine.
            input_name: Name of the single engine input.
            output_name: Name of the single engine output.
            class_names: Class labels exposed through ``class_names``.
            inference_config: Parsed inference configuration.
            trt_config: Parsed TRT configuration handed to the inference
                helper.
            device: CUDA device used for pre-processing and inference.
            cuda_context: CUDA context activated around each inference call.
            execution_context: TensorRT execution context of ``engine``.
        """
        self._engine = engine
        self._input_name = input_name
        # The inference helper takes a list of output names.
        self._output_names = [output_name]
        self._class_names = class_names
        self._inference_config = inference_config
        self._trt_config = trt_config
        self._device = device
        self._cuda_context = cuda_context
        self._execution_context = execution_context
        # Held for the whole duration of each ``forward`` call.
        self._lock = Lock()

    @property
    def class_names(self) -> List[str]:
        """Class labels supplied at construction time."""
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Turn input images into the engine input tensor.

        Args:
            images: Image or list of images, as tensors or numpy arrays.
            input_color_format: Optional color format of ``images``.
            image_size: Optional size override, forwarded as ``image_size_wh``
                (presumably ``(width, height)`` - confirm against
                ``pre_process_network_input``).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            First element of the value returned by
            ``pre_process_network_input``.
        """
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )[0]

    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> torch.Tensor:
        """Run the TensorRT engine on pre-processed input.

        The instance lock is taken first and the stored CUDA context is then
        activated, so calls on the same instance run one at a time.

        Args:
            pre_processed_images: Input produced by ``pre_process``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            First element of the results returned by ``infer_from_trt_engine``
            (the engine's single output).
        """
        with self._lock:
            with use_cuda_context(context=self._cuda_context):
                return infer_from_trt_engine(
                    pre_processed_images=pre_processed_images,
                    trt_config=self._trt_config,
                    engine=self._engine,
                    context=self._execution_context,
                    device=self._device,
                    input_name=self._input_name,
                    outputs=self._output_names,
                )[0]

    def post_process(
        self,
        model_results: torch.Tensor,
        confidence: float = 0.5,
        **kwargs,
    ) -> List[MultiLabelClassificationPrediction]:
        """Convert engine output into one prediction per batch element.

        Sigmoid is applied here unless the inference config marks it as fused
        into the model.

        Args:
            model_results: Output of ``forward``; iterated along its first
                dimension.
            confidence: Threshold; classes scoring at least this value are
                reported as predicted.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            One prediction per element of ``model_results``. ``class_ids``
            holds the indices whose score is ``>= confidence`` and
            ``confidence`` holds the full score vector of that element.
        """
        # Only transform when needed; a fused model already outputs scores.
        if not self._inference_config.post_processing.fused:
            model_results = torch.nn.functional.sigmoid(model_results)
        return [
            MultiLabelClassificationPrediction(
                # ``argwhere`` gives one index row per hit; the trailing
                # dimension is squeezed to obtain a flat vector of class ids.
                class_ids=torch.argwhere(
                    batch_element_confidence >= confidence
                ).squeeze(dim=-1),
                confidence=batch_element_confidence,
            )
            for batch_element_confidence in model_results
        ]
File without changes