inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/rfdetr_object_detection_trt.py
@@ -0,0 +1,270 @@
+import threading
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from inference_models import Detections, ObjectDetectionModel
+from inference_models.configuration import DEFAULT_DEVICE
+from inference_models.entities import ColorFormat
+from inference_models.errors import (
+    CorruptedModelPackageError,
+    MissingDependencyError,
+    ModelRuntimeError,
+)
+from inference_models.models.common.cuda import (
+    use_cuda_context,
+    use_primary_cuda_context,
+)
+from inference_models.models.common.model_packages import get_model_package_contents
+from inference_models.models.common.roboflow.model_packages import (
+    InferenceConfig,
+    PreProcessingMetadata,
+    ResizeMode,
+    TRTConfig,
+    parse_class_names_file,
+    parse_inference_config,
+    parse_trt_config,
+)
+from inference_models.models.common.roboflow.post_processing import (
+    rescale_image_detections,
+)
+from inference_models.models.common.roboflow.pre_processing import (
+    pre_process_network_input,
+)
+from inference_models.models.common.trt import (
+    get_engine_inputs_and_outputs,
+    infer_from_trt_engine,
+    load_model,
+)
+from inference_models.models.rfdetr.class_remapping import (
+    ClassesReMapping,
+    prepare_class_remapping,
+)
+
+try:
+    import tensorrt as trt
+except ImportError as import_error:
+    raise MissingDependencyError(
+        message=f"Could not import RFDetr model with TRT backend - this error means that some additional dependencies "
+        f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+        f"program, make sure the `trt10` extra of the package is installed - installation can only "
+        f"succeed on Linux and Windows machines with CUDA 12 installed. Jetson devices should have TRT 10.x "
+        f"installed, which is the case for all builds with JetPack 6. "
+        f"If you see this error while using Roboflow infrastructure, make sure the service you use supports the model. "
+        f"You can also contact Roboflow to get support.",
+        help_url="https://todo",
+    ) from import_error
+
+try:
+    import pycuda.driver as cuda
+except ImportError as import_error:
+    raise MissingDependencyError(
+        message="TODO", help_url="https://todo"
+    ) from import_error
+
+
+class RFDetrForObjectDetectionTRT(
+    (
+        ObjectDetectionModel[
+            torch.Tensor, PreProcessingMetadata, Tuple[torch.Tensor, torch.Tensor]
+        ]
+    )
+):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        device: torch.device = DEFAULT_DEVICE,
+        engine_host_code_allowed: bool = False,
+        **kwargs,
+    ) -> "RFDetrForObjectDetectionTRT":
+        if device.type != "cuda":
+            raise ModelRuntimeError(
+                message=f"TRT engine only runs on CUDA device - {device} device detected.",
+                help_url="https://todo",
+            )
+        model_package_content = get_model_package_contents(
+            model_package_dir=model_name_or_path,
+            elements=[
+                "class_names.txt",
+                "inference_config.json",
+                "trt_config.json",
+                "engine.plan",
+            ],
+        )
+        class_names = parse_class_names_file(
+            class_names_path=model_package_content["class_names.txt"]
+        )
+        inference_config = parse_inference_config(
+            config_path=model_package_content["inference_config.json"],
+            allowed_resize_modes={
+                ResizeMode.STRETCH_TO,
+                ResizeMode.LETTERBOX,
+                ResizeMode.CENTER_CROP,
+                ResizeMode.LETTERBOX_REFLECT_EDGES,
+            },
+        )
+        classes_re_mapping = None
+        if inference_config.class_names_operations:
+            class_names, classes_re_mapping = prepare_class_remapping(
+                class_names=class_names,
+                class_names_operations=inference_config.class_names_operations,
+                device=device,
+            )
+        trt_config = parse_trt_config(
+            config_path=model_package_content["trt_config.json"]
+        )
+        cuda.init()
+        cuda_device = cuda.Device(device.index or 0)
+        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
+            engine = load_model(
+                model_path=model_package_content["engine.plan"],
+                engine_host_code_allowed=engine_host_code_allowed,
+            )
+            execution_context = engine.create_execution_context()
+        inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
+        if len(inputs) != 1:
+            raise CorruptedModelPackageError(
+                message=f"Implementation assumes a single model input, found: {len(inputs)}.",
+                help_url="https://todo",
+            )
+        if len(outputs) != 2:
+            raise CorruptedModelPackageError(
+                message=f"Implementation assumes 2 model outputs, found: {len(outputs)}.",
+                help_url="https://todo",
+            )
+        if "dets" not in outputs or "labels" not in outputs:
+            raise CorruptedModelPackageError(
+                message=f"Expected model outputs to be named `dets` and `labels`, but found: {outputs}.",
+                help_url="https://todo",
+            )
+        return cls(
+            engine=engine,
+            input_name=inputs[0],
+            output_names=["dets", "labels"],
+            class_names=class_names,
+            classes_re_mapping=classes_re_mapping,
+            inference_config=inference_config,
+            trt_config=trt_config,
+            device=device,
+            cuda_context=cuda_context,
+            execution_context=execution_context,
+        )
+
+    def __init__(
+        self,
+        engine: trt.ICudaEngine,
+        input_name: str,
+        output_names: List[str],
+        class_names: List[str],
+        classes_re_mapping: Optional[ClassesReMapping],
+        inference_config: InferenceConfig,
+        trt_config: TRTConfig,
+        device: torch.device,
+        cuda_context: cuda.Context,
+        execution_context: trt.IExecutionContext,
+    ):
+        self._engine = engine
+        self._input_name = input_name
+        self._output_names = output_names
+        self._inference_config = inference_config
+        self._class_names = class_names
+        self._classes_re_mapping = classes_re_mapping
+        self._device = device
+        self._cuda_context = cuda_context
+        self._execution_context = execution_context
+        self._trt_config = trt_config
+        self._lock = threading.Lock()
+
+    @property
+    def class_names(self) -> List[str]:
+        return self._class_names
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+        return pre_process_network_input(
+            images=images,
+            image_pre_processing=self._inference_config.image_pre_processing,
+            network_input=self._inference_config.network_input,
+            target_device=self._device,
+            input_color_format=input_color_format,
+        )
+
+    def forward(
+        self, pre_processed_images: torch.Tensor, **kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        with self._lock:
+            with use_cuda_context(context=self._cuda_context):
+                detections, labels = infer_from_trt_engine(
+                    pre_processed_images=pre_processed_images,
+                    trt_config=self._trt_config,
+                    engine=self._engine,
+                    context=self._execution_context,
+                    device=self._device,
+                    input_name=self._input_name,
+                    outputs=self._output_names,
+                )
+                return detections, labels
+
+    def post_process(
+        self,
+        model_results: Tuple[torch.Tensor, torch.Tensor],
+        pre_processing_meta: List[PreProcessingMetadata],
+        threshold: float = 0.5,
+        **kwargs,
+    ) -> List[Detections]:
+        bboxes, logits = model_results
+        logits_sigmoid = torch.nn.functional.sigmoid(logits)
+        results = []
+        for image_bboxes, image_logits, image_meta in zip(
+            bboxes, logits_sigmoid, pre_processing_meta
+        ):
+            confidence, top_classes = image_logits.max(dim=1)
+            confidence_mask = confidence > threshold
+            confidence = confidence[confidence_mask]
+            top_classes = top_classes[confidence_mask]
+            selected_boxes = image_bboxes[confidence_mask]
+            confidence, sorted_indices = torch.sort(confidence, descending=True)
+            top_classes = top_classes[sorted_indices]
+            selected_boxes = selected_boxes[sorted_indices]
+            if self._classes_re_mapping is not None:
+                remapping_mask = torch.isin(
+                    top_classes, self._classes_re_mapping.remaining_class_ids
+                )
+                top_classes = self._classes_re_mapping.class_mapping[
+                    top_classes[remapping_mask]
+                ]
+                selected_boxes = selected_boxes[remapping_mask]
+                confidence = confidence[remapping_mask]
+            cxcy = selected_boxes[:, :2]
+            wh = selected_boxes[:, 2:]
+            xy_min = cxcy - 0.5 * wh
+            xy_max = cxcy + 0.5 * wh
+            selected_boxes_xyxy_pct = torch.cat([xy_min, xy_max], dim=-1)
+            inference_size_hwhw = torch.tensor(
+                [
+                    image_meta.inference_size.height,
+                    image_meta.inference_size.width,
+                    image_meta.inference_size.height,
+                    image_meta.inference_size.width,
+                ],
+                device=self._device,
+            )
+            selected_boxes_xyxy = selected_boxes_xyxy_pct * inference_size_hwhw
+            selected_boxes_xyxy = rescale_image_detections(
+                image_detections=selected_boxes_xyxy,
+                image_metadata=image_meta,
+            )
+            detections = Detections(
+                xyxy=selected_boxes_xyxy.round().int(),
+                confidence=confidence,
+                class_id=top_classes.int(),
+            )
+            results.append(detections)
+        return results
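
The class above follows the pre_process -> forward -> post_process contract of ObjectDetectionModel. A minimal usage sketch, assuming a locally downloaded TRT model package directory and a CUDA device; the directory path and the input frame below are hypothetical:

    import numpy as np
    import torch
    from inference_models.models.rfdetr.rfdetr_object_detection_trt import (
        RFDetrForObjectDetectionTRT,
    )

    model = RFDetrForObjectDetectionTRT.from_pretrained(
        "/path/to/model-package",  # hypothetical dir with engine.plan, trt_config.json, inference_config.json, class_names.txt
        device=torch.device("cuda:0"),
    )
    frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # placeholder image array
    batch, metadata = model.pre_process([frame])
    raw_outputs = model.forward(batch)
    detections = model.post_process(raw_outputs, metadata, threshold=0.5)
    print(detections[0].xyxy, detections[0].class_id, detections[0].confidence)
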
inference_models/models/rfdetr/segmentation_head.py
@@ -0,0 +1,273 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+
+from typing import Callable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DepthwiseConvBlock(nn.Module):
+    r"""Simplified ConvNeXt block without the MLP subnet"""
+
+    def __init__(self, dim, layer_scale_init_value=0):
+        super().__init__()
+        self.dwconv = nn.Conv2d(
+            dim, dim, kernel_size=3, padding=1, groups=dim
+        )  # depthwise conv
+        self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        return x + input
+
+
+class MLPBlock(nn.Module):
+    def __init__(self, dim, layer_scale_init_value=0):
+        super().__init__()
+        self.norm_in = nn.LayerNorm(dim)
+        self.layers = nn.ModuleList(
+            [
+                nn.Linear(dim, dim * 4),
+                nn.GELU(),
+                nn.Linear(dim * 4, dim),
+            ]
+        )
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+
+    def forward(self, x):
+        input = x
+        x = self.norm_in(x)
+        for layer in self.layers:
+            x = layer(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        return x + input
+
+
+class SegmentationHead(nn.Module):
+    def __init__(
+        self,
+        in_dim,
+        num_blocks: int,
+        bottleneck_ratio: int = 1,
+        downsample_ratio: int = 4,
+    ):
+        super().__init__()
+
+        self.downsample_ratio = downsample_ratio
+        self.interaction_dim = (
+            in_dim // bottleneck_ratio if bottleneck_ratio is not None else in_dim
+        )
+        self.blocks = nn.ModuleList(
+            [DepthwiseConvBlock(in_dim) for _ in range(num_blocks)]
+        )
+        self.spatial_features_proj = (
+            nn.Identity()
+            if bottleneck_ratio is None
+            else nn.Conv2d(in_dim, self.interaction_dim, kernel_size=1)
+        )
+
+        self.query_features_block = MLPBlock(in_dim)
+        self.query_features_proj = (
+            nn.Identity()
+            if bottleneck_ratio is None
+            else nn.Linear(in_dim, self.interaction_dim)
+        )
+
+        self.bias = nn.Parameter(torch.zeros(1), requires_grad=True)
+
+        self._export = False
+
+    def export(self):
+        self._export = True
+        self._forward_origin = self.forward
+        self.forward = self.forward_export
+        for name, m in self.named_modules():
+            if (
+                hasattr(m, "export")
+                and isinstance(m.export, Callable)
+                and hasattr(m, "_export")
+                and not m._export
+            ):
+                m.export()
+
+    def forward(
+        self,
+        spatial_features: torch.Tensor,
+        query_features: list[torch.Tensor],
+        image_size: tuple[int, int],
+        skip_blocks: bool = False,
+    ) -> list[torch.Tensor]:
+        # spatial features: (B, C, H, W)
+        # query features: [(B, N, C)] for each decoder layer
+        # output: (B, N, H*r, W*r)
+        target_size = (
+            image_size[0] // self.downsample_ratio,
+            image_size[1] // self.downsample_ratio,
+        )
+        spatial_features = F.interpolate(
+            spatial_features, size=target_size, mode="bilinear", align_corners=False
+        )
+
+        mask_logits = []
+        if not skip_blocks:
+            for block, qf in zip(self.blocks, query_features):
+                spatial_features = block(spatial_features)
+                spatial_features_proj = self.spatial_features_proj(spatial_features)
+                qf = self.query_features_proj(self.query_features_block(qf))
+                mask_logits.append(
+                    torch.einsum("bchw,bnc->bnhw", spatial_features_proj, qf)
+                    + self.bias
+                )
+        else:
+            assert (
+                len(query_features) == 1
+            ), "skip_blocks is only supported for length 1 query features"
+            qf = self.query_features_proj(self.query_features_block(query_features[0]))
+            mask_logits.append(
+                torch.einsum("bchw,bnc->bnhw", spatial_features, qf) + self.bias
+            )
+
+        return mask_logits
+
+    def forward_export(
+        self,
+        spatial_features: torch.Tensor,
+        query_features: list[torch.Tensor],
+        image_size: tuple[int, int],
+        skip_blocks: bool = False,
+    ) -> list[torch.Tensor]:
+        assert (
+            len(query_features) == 1
+        ), "at export time, segmentation head expects exactly one query feature"
+
+        target_size = (
+            image_size[0] // self.downsample_ratio,
+            image_size[1] // self.downsample_ratio,
+        )
+        spatial_features = F.interpolate(
+            spatial_features, size=target_size, mode="bilinear", align_corners=False
+        )
+
+        if not skip_blocks:
+            for block in self.blocks:
+                spatial_features = block(spatial_features)
+
+        spatial_features_proj = self.spatial_features_proj(spatial_features)
+
+        qf = self.query_features_proj(self.query_features_block(query_features[0]))
+        return [torch.einsum("bchw,bnc->bnhw", spatial_features_proj, qf) + self.bias]
+
+
+def point_sample(input, point_coords, **kwargs):
+    """
+    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
+    Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
+    [0, 1] x [0, 1] square.
+    Args:
+        input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
+        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
+            [0, 1] x [0, 1] normalized point coordinates.
+    Returns:
+        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
+            features for points in `point_coords`. The features are obtained via bilinear
+            interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
+    """
+    add_dim = False
+    if point_coords.dim() == 3:
+        add_dim = True
+        point_coords = point_coords.unsqueeze(2)
+    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
+    if add_dim:
+        output = output.squeeze(3)
+    return output
+
+
+def get_uncertain_point_coords_with_randomness(
+    coarse_logits,
+    uncertainty_func,
+    num_points,
+    oversample_ratio=3,
+    importance_sample_ratio=0.75,
+):
+    """
+    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
+    are calculated for each point using the 'uncertainty_func' function that takes a point's logit
+    prediction as input.
+    See PointRend paper for details.
+    Args:
+        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
+            class-specific or class-agnostic prediction.
+        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
+            contains logit predictions for P points and returns their uncertainties as a Tensor of
+            shape (N, 1, P).
+        num_points (int): The number of points P to sample.
+        oversample_ratio (int): Oversampling parameter.
+        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
+    Returns:
+        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
+            sampled points.
+    """
+    assert oversample_ratio >= 1
+    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
+    num_boxes = coarse_logits.shape[0]
+    num_sampled = int(num_points * oversample_ratio)
+    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
+    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
+    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
+    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
+    # to incorrect results.
+    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
+    # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
+    # However, if we calculate uncertainties for the coarse predictions first,
+    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
+    point_uncertainties = uncertainty_func(point_logits)
+    num_uncertain_points = int(importance_sample_ratio * num_points)
+    num_random_points = num_points - num_uncertain_points
+    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+    shift = num_sampled * torch.arange(
+        num_boxes, dtype=torch.long, device=coarse_logits.device
+    )
+    idx += shift[:, None]
+    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+        num_boxes, num_uncertain_points, 2
+    )
+    if num_random_points > 0:
+        point_coords = torch.cat(
+            [
+                point_coords,
+                torch.rand(
+                    num_boxes, num_random_points, 2, device=coarse_logits.device
+                ),
+            ],
+            dim=1,
+        )
+    return point_coords
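
point_sample and get_uncertain_point_coords_with_randomness implement PointRend-style importance sampling over mask logits. A small sketch of how they compose, with illustrative shapes and the negative-absolute-logit uncertainty function mentioned in the in-line comments above:

    import torch
    from inference_models.models.rfdetr.segmentation_head import (
        get_uncertain_point_coords_with_randomness,
        point_sample,
    )

    coarse_logits = torch.randn(2, 1, 28, 28)  # (N, 1, Hmask, Wmask) class-agnostic mask logits

    def uncertainty(logits):
        # least certain where the logit is closest to zero
        return -torch.abs(logits)

    point_coords = get_uncertain_point_coords_with_randomness(
        coarse_logits,
        uncertainty,
        num_points=196,
        oversample_ratio=3,
        importance_sample_ratio=0.75,
    )  # -> (2, 196, 2), coordinates in [0, 1] x [0, 1]
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)  # -> (2, 1, 196)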