inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,206 @@
+ import threading
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import InstanceDetections, InstanceSegmentationModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     EnvironmentConfigurationError,
+     MissingDependencyError,
+ )
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.onnx import (
+     run_session_with_batch_size_limit,
+     set_execution_provider_defaults,
+ )
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     PreProcessingMetadata,
+     ResizeMode,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.models.rfdetr.class_remapping import (
+     ClassesReMapping,
+     prepare_class_remapping,
+ )
+ from inference_models.models.rfdetr.common import (
+     post_process_instance_segmentation_results,
+ )
+ from inference_models.utils.onnx_introspection import (
+     get_selected_onnx_execution_providers,
+ )
+
+ try:
+     import onnxruntime
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message="Could not import RFDetr model with ONNX backend - this error means that some additional "
+         "dependencies are not installed in the environment. If you run the `inference-models` library directly "
+         "in your Python program, make sure the following extras of the package are installed:\n"
+         "\t* `onnx-cpu` - when you wish to use the library with CPU support only\n"
+         "\t* `onnx-cu12` - for running on GPU with CUDA 12 installed\n"
+         "\t* `onnx-cu118` - for running on GPU with CUDA 11.8 installed\n"
+         "\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
+         "If you see this error using Roboflow infrastructure, make sure the service you use supports the model. "
+         "You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
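The extras named in the error message above map to standard pip extras syntax - a hedged example (extra names as listed in the message, CUDA 12 variant chosen for illustration):

    pip install "inference-models[onnx-cu12]"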
+
+
+ class RFDetrForInstanceSegmentationOnnx(
+     InstanceSegmentationModel[
+         torch.Tensor,
+         PreProcessingMetadata,
+         Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+     ]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "RFDetrForInstanceSegmentationOnnx":
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message="Could not initialize model - the selected backend is ONNX, which requires an execution "
+                 "provider to be specified - either explicitly in the `from_pretrained(...)` method or via the "
+                 "`ONNXRUNTIME_EXECUTION_PROVIDERS` env variable. If you run the model locally - adjust your "
+                 "setup, otherwise contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.onnx",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         classes_re_mapping = None
+         if inference_config.class_names_operations:
+             class_names, classes_re_mapping = prepare_class_remapping(
+                 class_names=class_names,
+                 class_names_operations=inference_config.class_names_operations,
+                 device=device,
+             )
+         session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["weights.onnx"],
+             providers=onnx_execution_providers,
+         )
+         input_batch_size = session.get_inputs()[0].shape[0]
+         if isinstance(input_batch_size, str):
+             # a dynamic batch dimension is reported as a symbolic (string) value
+             input_batch_size = None
+         input_name = session.get_inputs()[0].name
+         return cls(
+             session=session,
+             input_name=input_name,
+             class_names=class_names,
+             classes_re_mapping=classes_re_mapping,
+             inference_config=inference_config,
+             device=device,
+             input_batch_size=input_batch_size,
+         )
+
+     def __init__(
+         self,
+         session: onnxruntime.InferenceSession,
+         input_name: str,
+         class_names: List[str],
+         classes_re_mapping: Optional[ClassesReMapping],
+         inference_config: InferenceConfig,
+         device: torch.device,
+         input_batch_size: Optional[int],
+     ):
+         self._session = session
+         self._input_name = input_name
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._classes_re_mapping = classes_re_mapping
+         self._device = device
+         self._min_batch_size = input_batch_size
+         self._max_batch_size = (
+             input_batch_size
+             if input_batch_size is not None
+             else inference_config.forward_pass.max_dynamic_batch_size
+         )
+         self._session_thread_lock = threading.Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+             image_size_wh=image_size,
+         )
+
+     def forward(
+         self, pre_processed_images: torch.Tensor, **kwargs
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         with self._session_thread_lock:
+             bboxes, logits, masks = run_session_with_batch_size_limit(
+                 session=self._session,
+                 inputs={self._input_name: pre_processed_images},
+                 min_batch_size=self._min_batch_size,
+                 max_batch_size=self._max_batch_size,
+             )
+         return bboxes, logits, masks
+
+     def post_process(
+         self,
+         model_results: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+         pre_processing_meta: List[PreProcessingMetadata],
+         threshold: float = 0.5,
+         **kwargs,
+     ) -> List[InstanceDetections]:
+         bboxes, logits, masks = model_results
+         return post_process_instance_segmentation_results(
+             bboxes=bboxes,
+             logits=logits,
+             masks=masks,
+             pre_processing_meta=pre_processing_meta,
+             threshold=threshold,
+             classes_re_mapping=self._classes_re_mapping,
+         )
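For orientation, a minimal usage sketch of the ONNX class above - paths are hypothetical, and it assumes a local model package directory containing the three files requested in `from_pretrained`, with one of the `onnx-*` extras installed:

    import cv2
    from inference_models.models.rfdetr.rfdetr_instance_segmentation_onnx import (
        RFDetrForInstanceSegmentationOnnx,
    )

    model = RFDetrForInstanceSegmentationOnnx.from_pretrained(
        "/path/to/model_package",  # hypothetical dir with class_names.txt, inference_config.json, weights.onnx
        onnx_execution_providers=["CPUExecutionProvider"],
    )
    image = cv2.imread("image.jpg")  # BGR numpy array
    pre_processed, metadata = model.pre_process(image)
    raw_results = model.forward(pre_processed)
    detections = model.post_process(raw_results, metadata, threshold=0.5)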
@@ -0,0 +1,373 @@
+ import os.path
+ from copy import deepcopy
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import InstanceDetections, InstanceSegmentationModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     CorruptedModelPackageError,
+     ModelLoadingError,
+     ModelRuntimeError,
+ )
+ from inference_models.logger import LOGGER
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.model_packages import (
+     ColorMode,
+     DivisiblePadding,
+     InferenceConfig,
+     NetworkInputDefinition,
+     PreProcessingMetadata,
+     ResizeMode,
+     TrainingInputSize,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.models.rfdetr.class_remapping import (
+     ClassesReMapping,
+     prepare_class_remapping,
+ )
+ from inference_models.models.rfdetr.common import (
+     parse_model_type,
+     post_process_instance_segmentation_results,
+ )
+ from inference_models.models.rfdetr.default_labels import resolve_labels
+ from inference_models.models.rfdetr.post_processor import PostProcess
+ from inference_models.models.rfdetr.rfdetr_base_pytorch import (
+     LWDETR,
+     RFDETRSegPreviewConfig,
+     build_model,
+ )
+
+ try:
+     torch.set_float32_matmul_precision("high")
+ except Exception:
+     # not supported on every torch build - safe to ignore
+     pass
+
+ CONFIG_FOR_MODEL_TYPE = {
+     "rfdetr-seg-preview": RFDETRSegPreviewConfig,
+ }
+
+
+ class RFDetrForInstanceSegmentationTorch(
+     InstanceSegmentationModel[
+         torch.Tensor,
+         PreProcessingMetadata,
+         dict,
+     ]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         model_type: Optional[str] = None,
+         labels: Optional[Union[str, List[str]]] = None,
+         resolution: Optional[int] = None,
+         **kwargs,
+     ) -> "RFDetrForInstanceSegmentationTorch":
+         if os.path.isfile(model_name_or_path):
+             return cls.from_checkpoint_file(
+                 checkpoint_path=model_name_or_path,
+                 model_type=model_type,
+                 labels=labels,
+                 resolution=resolution,
+                 device=device,
+             )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "model_type.json",
+                 "weights.pth",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         classes_re_mapping = None
+         if inference_config.class_names_operations:
+             class_names, classes_re_mapping = prepare_class_remapping(
+                 class_names=class_names,
+                 class_names_operations=inference_config.class_names_operations,
+                 device=device,
+             )
+         weights_dict = torch.load(
+             model_package_content["weights.pth"],
+             map_location=device,
+             weights_only=False,
+         )["model"]
+         model_type = parse_model_type(
+             config_path=model_package_content["model_type.json"]
+         )
+         if model_type not in CONFIG_FOR_MODEL_TYPE:
+             raise CorruptedModelPackageError(
+                 message=f"Model package describes model_type as '{model_type}', which is not supported. "
+                 f"Supported model types: {list(CONFIG_FOR_MODEL_TYPE.keys())}.",
+                 help_url="https://todo",
+             )
+         model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
+         # number of outputs in the checkpoint's classification head
+         checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
+         model_config.num_classes = checkpoint_num_classes - 1
+         model_config.resolution = (
+             inference_config.network_input.training_input_size.height
+         )
+         model = build_model(config=model_config)
+         model.load_state_dict(weights_dict)
+         model = model.eval().to(device)
+         post_processor = PostProcess()
+         return cls(
+             model=model,
+             class_names=class_names,
+             classes_re_mapping=classes_re_mapping,
+             device=device,
+             inference_config=inference_config,
+             post_processor=post_processor,
+             resolution=model_config.resolution,
+         )
+
+     @classmethod
+     def from_checkpoint_file(
+         cls,
+         checkpoint_path: str,
+         model_type: Optional[str] = "rfdetr-seg-preview",
+         labels: Optional[Union[str, List[str]]] = None,
+         resolution: Optional[int] = None,
+         device: torch.device = DEFAULT_DEVICE,
+     ):
+         if model_type is None:
+             raise ModelLoadingError(
+                 message="Could not determine `model_type` while loading RFDetr model (using torch backend). "
+                 "If you imported `RFDetrForInstanceSegmentationTorch` directly in your code, please pass "
+                 f"one of the values: {list(CONFIG_FOR_MODEL_TYPE.keys())} as the parameter. If you see this "
+                 "error while using `AutoModel.from_pretrained(...)`, or thrown from a managed Roboflow service, "
+                 "this is a bug - raise an issue at https://github.com/roboflow/inference/issues providing "
+                 "full context.",
+                 help_url="https://todo",
+             )
+         weights_dict = torch.load(
+             checkpoint_path,
+             map_location=device,
+             weights_only=False,
+         )["model"]
+         if model_type not in CONFIG_FOR_MODEL_TYPE:
+             raise ModelLoadingError(
+                 message=f"Provided model_type '{model_type}' is not supported. "
+                 f"Supported model types: {list(CONFIG_FOR_MODEL_TYPE.keys())}.",
+                 help_url="https://todo",
+             )
+         model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
+         divisibility = model_config.num_windows * model_config.patch_size
+         if resolution is not None:
+             if resolution <= 0 or resolution % divisibility != 0:
+                 raise ModelLoadingError(
+                     message=f"Attempted to load RFDetr model (using torch backend) with an invalid `resolution` "
+                     f"parameter - the model requires a positive value divisible by {divisibility}. Make sure "
+                     f"you used the proper value, corresponding to the one used to train the model.",
+                     help_url="https://todo",
+                 )
+             model_config.resolution = resolution
+         inference_config = InferenceConfig(
+             network_input=NetworkInputDefinition(
+                 training_input_size=TrainingInputSize(
+                     height=model_config.resolution,
+                     width=model_config.resolution,
+                 ),
+                 dynamic_spatial_size_supported=True,
+                 dynamic_spatial_size_mode=DivisiblePadding(
+                     type="pad-to-be-divisible",
+                     value=divisibility,
+                 ),
+                 color_mode=ColorMode.BGR,
+                 resize_mode=ResizeMode.STRETCH_TO,
+                 input_channels=3,
+                 scaling_factor=255,
+                 normalization=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+             )
+         )
+         checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
+         model_config.num_classes = checkpoint_num_classes - 1
+         model = build_model(config=model_config)
+         if labels is None:
+             class_names = [f"class_{i}" for i in range(checkpoint_num_classes)]
+         elif isinstance(labels, str):
+             class_names = resolve_labels(labels=labels)
+         else:
+             class_names = labels
+         if checkpoint_num_classes != len(class_names):
+             raise ModelLoadingError(
+                 message=f"Checkpoint pointed to load RFDetr defines {checkpoint_num_classes} output classes, but "
+                 f"loaded labels define {len(class_names)} classes - fix the value of the `labels` parameter.",
+                 help_url="https://todo",
+             )
+         model.load_state_dict(weights_dict)
+         model = model.eval().to(device)
+         post_processor = PostProcess()
+         return cls(
+             model=model,
+             class_names=class_names,
+             classes_re_mapping=None,
+             device=device,
+             inference_config=inference_config,
+             post_processor=post_processor,
+             resolution=model_config.resolution,
+         )
+
+     def __init__(
+         self,
+         model: LWDETR,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         classes_re_mapping: Optional[ClassesReMapping],
+         device: torch.device,
+         post_processor: PostProcess,
+         resolution: int,
+     ):
+         self._model = model
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._classes_re_mapping = classes_re_mapping
+         self._post_processor = post_processor
+         self._device = device
+         self._resolution = resolution
+         self._has_warned_about_not_being_optimized_for_inference = False
+         self._inference_model: Optional[LWDETR] = None
+         self._optimized_has_been_compiled = False
+         self._optimized_batch_size = None
+         self._optimized_dtype = None
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def optimize_for_inference(
+         self,
+         compile: bool = True,
+         batch_size: int = 1,
+         dtype: torch.dtype = torch.float32,
+     ) -> None:
+         self.remove_optimized_model()
+         self._inference_model = deepcopy(self._model)
+         self._inference_model.eval()
+         self._inference_model.export()
+         self._inference_model = self._inference_model.to(dtype=dtype)
+         self._optimized_dtype = dtype
+         if compile:
+             self._inference_model = torch.jit.trace(
+                 self._inference_model,
+                 torch.randn(
+                     batch_size,
+                     3,
+                     self._resolution,
+                     self._resolution,
+                     device=self._device,
+                     dtype=dtype,
+                 ),
+             )
+             self._optimized_has_been_compiled = True
+             self._optimized_batch_size = batch_size
+
+     def remove_optimized_model(self) -> None:
+         self._has_warned_about_not_being_optimized_for_inference = False
+         self._inference_model = None
+         self._optimized_has_been_compiled = False
+         self._optimized_batch_size = None
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+             image_size_wh=image_size,
+         )
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> dict:
+         if (
+             self._inference_model is None
+             and not self._has_warned_about_not_being_optimized_for_inference
+         ):
+             LOGGER.warning(
+                 "Model is not optimized for inference. "
+                 "Latency may be higher than expected. "
+                 "You can optimize the model for inference by calling model.optimize_for_inference()."
+             )
+             self._has_warned_about_not_being_optimized_for_inference = True
+         if self._inference_model is not None:
+             if (self._resolution, self._resolution) != tuple(
+                 pre_processed_images.shape[2:]
+             ):
+                 raise ModelRuntimeError(
+                     message=f"Resolution mismatch. Model was optimized for resolution {self._resolution}, "
+                     f"but got {tuple(pre_processed_images.shape[2:])}. "
+                     "You can explicitly remove the optimized model by calling model.remove_optimized_model().",
+                     help_url="https://todo",
+                 )
+             if self._optimized_has_been_compiled:
+                 if self._optimized_batch_size != pre_processed_images.shape[0]:
+                     raise ModelRuntimeError(
+                         message="Batch size mismatch. Optimized model was compiled for batch size "
+                         f"{self._optimized_batch_size}, but got {pre_processed_images.shape[0]}. "
+                         "You can explicitly remove the optimized model by calling model.remove_optimized_model(). "
+                         "Alternatively, you can recompile the optimized model for a different batch size "
+                         "by calling model.optimize_for_inference(batch_size=<new_batch_size>).",
+                         help_url="https://todo",
+                     )
+         with torch.inference_mode():
+             if self._inference_model:
+                 predictions = self._inference_model(
+                     pre_processed_images.to(dtype=self._optimized_dtype)
+                 )
+             else:
+                 predictions = self._model(pre_processed_images)
+             if isinstance(predictions, tuple):
+                 # traced/exported model returns a tuple - normalize it to a dict
+                 predictions = {
+                     "pred_logits": predictions[1],
+                     "pred_boxes": predictions[0],
+                     "pred_masks": predictions[2],
+                 }
+             return predictions
+
+     def post_process(
+         self,
+         model_results: dict,
+         pre_processing_meta: List[PreProcessingMetadata],
+         threshold: float = 0.5,
+         **kwargs,
+     ) -> List[InstanceDetections]:
+         bboxes, logits, masks = (
+             model_results["pred_boxes"],
+             model_results["pred_logits"],
+             model_results["pred_masks"],
+         )
+         return post_process_instance_segmentation_results(
+             bboxes=bboxes,
+             logits=logits,
+             masks=masks,
+             pre_processing_meta=pre_processing_meta,
+             threshold=threshold,
+             classes_re_mapping=self._classes_re_mapping,
+         )
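Similarly, a minimal sketch for the torch backend above, assuming a hypothetical local checkpoint file; the `optimize_for_inference` call is optional and traces the model for a fixed batch size:

    import cv2
    from inference_models.models.rfdetr.rfdetr_instance_segmentation_pytorch import (
        RFDetrForInstanceSegmentationTorch,
    )

    model = RFDetrForInstanceSegmentationTorch.from_pretrained(
        "/path/to/checkpoint.pth",  # hypothetical checkpoint file
        labels=["cat", "dog"],  # length must match the checkpoint's class head
    )
    model.optimize_for_inference(batch_size=1)
    image = cv2.imread("image.jpg")  # BGR numpy array
    pre_processed, metadata = model.pre_process(image)
    raw_results = model.forward(pre_processed)
    detections = model.post_process(raw_results, metadata, threshold=0.5)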