inference_models-0.18.3-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/misc.py
@@ -0,0 +1,26 @@
+ from typing import Optional
+
+ from torch import Tensor
+
+
+ class NestedTensor(object):
+     def __init__(self, tensors, mask: Optional[Tensor]):
+         self.tensors = tensors
+         self.mask = mask
+
+     def to(self, device):
+         # type: (Device) -> NestedTensor # noqa
+         cast_tensor = self.tensors.to(device)
+         mask = self.mask
+         if mask is not None:
+             assert mask is not None
+             cast_mask = mask.to(device)
+         else:
+             cast_mask = None
+         return NestedTensor(cast_tensor, cast_mask)
+
+     def decompose(self):
+         return self.tensors, self.mask
+
+     def __repr__(self):
+         return str(self.tensors)
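A minimal usage sketch for NestedTensor (illustrative, not part of the diff); the image size and the all-False padding mask are assumptions:

import torch

from inference_models.models.rfdetr.misc import NestedTensor

# A batch of 2 images padded to 640x640; the boolean mask is True where
# pixels are padding and False where they carry real image content.
images = torch.randn(2, 3, 640, 640)
mask = torch.zeros(2, 640, 640, dtype=torch.bool)

nested = NestedTensor(images, mask)
nested = nested.to("cuda" if torch.cuda.is_available() else "cpu")
tensors, mask = nested.decompose()  # recover the raw tensor and its mask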
inference_models/models/rfdetr/ms_deform_attn.py
@@ -0,0 +1,180 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------------------------------
+ # Modified from Deformable DETR
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
+ # ------------------------------------------------------------------------------------------------
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ # ------------------------------------------------------------------------------------------------
+ """
+ Multi-Scale Deformable Attention Module
+ """
+
+ from __future__ import absolute_import, division, print_function
+
+ import math
+ import warnings
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn.init import constant_, xavier_uniform_
+
+ from inference_models.models.rfdetr.ms_deform_attn_func import (
+     ms_deform_attn_core_pytorch,
+ )
+
+
+ def _is_power_of_2(n):
+     if (not isinstance(n, int)) or (n < 0):
+         raise ValueError(
+             "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
+         )
+     return (n & (n - 1) == 0) and n != 0
+
+
+ class MSDeformAttn(nn.Module):
+     """Multi-Scale Deformable Attention Module"""
+
+     def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+         """
+         Multi-Scale Deformable Attention Module
+         :param d_model: hidden dimension
+         :param n_levels: number of feature levels
+         :param n_heads: number of attention heads
+         :param n_points: number of sampling points per attention head per feature level
+         """
+         super().__init__()
+         if d_model % n_heads != 0:
+             raise ValueError(
+                 "d_model must be divisible by n_heads, but got {} and {}".format(
+                     d_model, n_heads
+                 )
+             )
+         _d_per_head = d_model // n_heads
+         # setting _d_per_head to a power of 2 makes the CUDA implementation more efficient
+         if not _is_power_of_2(_d_per_head):
+             warnings.warn(
+                 "You'd better set d_model in MSDeformAttn to make the "
+                 "dimension of each attention head a power of 2 "
+                 "which is more efficient in our CUDA implementation."
+             )
+
+         self.im2col_step = 64
+
+         self.d_model = d_model
+         self.n_levels = n_levels
+         self.n_heads = n_heads
+         self.n_points = n_points
+
+         self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+         self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+         self.value_proj = nn.Linear(d_model, d_model)
+         self.output_proj = nn.Linear(d_model, d_model)
+
+         self._reset_parameters()
+
+         self._export = False
+
+     def export(self):
+         """Switch the module into export mode."""
+         self._export = True
+
+     def _reset_parameters(self):
+         constant_(self.sampling_offsets.weight.data, 0.0)
+         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
+             2.0 * math.pi / self.n_heads
+         )
+         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+         grid_init = (
+             (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+             .view(self.n_heads, 1, 1, 2)
+             .repeat(1, self.n_levels, self.n_points, 1)
+         )
+         for i in range(self.n_points):
+             grid_init[:, :, i, :] *= i + 1
+         with torch.no_grad():
+             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+         constant_(self.attention_weights.weight.data, 0.0)
+         constant_(self.attention_weights.bias.data, 0.0)
+         xavier_uniform_(self.value_proj.weight.data)
+         constant_(self.value_proj.bias.data, 0.0)
+         xavier_uniform_(self.output_proj.weight.data)
+         constant_(self.output_proj.bias.data, 0.0)
+
+     def forward(
+         self,
+         query,
+         reference_points,
+         input_flatten,
+         input_spatial_shapes,
+         input_level_start_index,
+         input_padding_mask=None,
+     ):
+         """
+         :param query: (N, Length_{query}, C)
+         :param reference_points: (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0, 0), bottom-right (1, 1), including padding area,
+             or (N, Length_{query}, n_levels, 4), with an additional (w, h) to form reference boxes
+         :param input_flatten: (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+         :param input_spatial_shapes: (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+         :param input_level_start_index: (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, ..., H_0*W_0+H_1*W_1+...+H_{L-2}*W_{L-2}]
+         :param input_padding_mask: (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
+
+         :return output: (N, Length_{query}, C)
+         """
+         N, Len_q, _ = query.shape
+         N, Len_in, _ = input_flatten.shape
+         assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+         value = self.value_proj(input_flatten)
+         if input_padding_mask is not None:
+             value = value.masked_fill(input_padding_mask[..., None], float(0))
+
+         sampling_offsets = self.sampling_offsets(query).view(
+             N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
+         )
+         attention_weights = self.attention_weights(query).view(
+             N, Len_q, self.n_heads, self.n_levels * self.n_points
+         )
+
+         # N, Len_q, n_heads, n_levels, n_points, 2
+         if reference_points.shape[-1] == 2:
+             offset_normalizer = torch.stack(
+                 [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
+             )
+             sampling_locations = (
+                 reference_points[:, :, None, :, None, :]
+                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+             )
+         elif reference_points.shape[-1] == 4:
+             sampling_locations = (
+                 reference_points[:, :, None, :, None, :2]
+                 + sampling_offsets
+                 / self.n_points
+                 * reference_points[:, :, None, :, None, 2:]
+                 * 0.5
+             )
+         else:
+             raise ValueError(
+                 "Last dim of reference_points must be 2 or 4, but got {} instead.".format(
+                     reference_points.shape[-1]
+                 )
+             )
+         attention_weights = F.softmax(attention_weights, -1)
+
+         value = (
+             value.transpose(1, 2)
+             .contiguous()
+             .view(N, self.n_heads, self.d_model // self.n_heads, Len_in)
+         )
+         output = ms_deform_attn_core_pytorch(
+             value, input_spatial_shapes, sampling_locations, attention_weights
+         )
+         output = self.output_proj(output)
+         return output
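For orientation, a smoke-test sketch of MSDeformAttn (not part of the diff); the batch size, query count, and the four feature-level resolutions are illustrative assumptions:

import torch

from inference_models.models.rfdetr.ms_deform_attn import MSDeformAttn

attn = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4)

# Four feature levels pyramid; Len_in = sum(H * W) over levels = 5440.
spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16], [8, 8]])
level_areas = spatial_shapes[:, 0] * spatial_shapes[:, 1]
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1), level_areas.cumsum(0)[:-1])
)
N, Len_q, Len_in = 2, 300, int(level_areas.sum())

query = torch.randn(N, Len_q, 256)
input_flatten = torch.randn(N, Len_in, 256)
reference_points = torch.rand(N, Len_q, 4, 2)  # normalized (x, y) per level

out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
print(out.shape)  # torch.Size([2, 300, 256])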
inference_models/models/rfdetr/ms_deform_attn_func.py
@@ -0,0 +1,60 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------------------------------
+ # Modified from Deformable DETR
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
+ # ------------------------------------------------------------------------------------------------
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ # ------------------------------------------------------------------------------------------------
+ """
+ ms_deform_attn_func
+ """
+ from __future__ import absolute_import, division, print_function
+
+ import torch
+ import torch.nn.functional as F
+ from torch.autograd import Function
+ from torch.autograd.function import once_differentiable
+
+
+ def ms_deform_attn_core_pytorch(
+     value, value_spatial_shapes, sampling_locations, attention_weights
+ ):
+     """For debug and test only; use the CUDA version instead."""
+     # B, n_heads, head_dim, N
+     B, n_heads, head_dim, _ = value.shape
+     _, Len_q, n_heads, L, P, _ = sampling_locations.shape
+     value_list = value.split([H * W for H, W in value_spatial_shapes], dim=3)
+     sampling_grids = 2 * sampling_locations - 1
+     sampling_value_list = []
+     for lid_, (H, W) in enumerate(value_spatial_shapes):
+         # B, n_heads, head_dim, H, W
+         value_l_ = value_list[lid_].view(B * n_heads, head_dim, H, W)
+         # B, Len_q, n_heads, P, 2 -> B, n_heads, Len_q, P, 2 -> B*n_heads, Len_q, P, 2
+         sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+         # B*n_heads, head_dim, Len_q, P
+         sampling_value_l_ = F.grid_sample(
+             value_l_,
+             sampling_grid_l_,
+             mode="bilinear",
+             padding_mode="zeros",
+             align_corners=False,
+         )
+         sampling_value_list.append(sampling_value_l_)
+     # (B, Len_q, n_heads, L * P) -> (B, n_heads, Len_q, L * P) -> (B*n_heads, 1, Len_q, L*P)
+     attention_weights = attention_weights.transpose(1, 2).reshape(
+         B * n_heads, 1, Len_q, L * P
+     )
+     # B*n_heads, head_dim, Len_q, L*P
+     sampling_value_list = torch.stack(sampling_value_list, dim=-2).flatten(-2)
+     output = (
+         (sampling_value_list * attention_weights)
+         .sum(-1)
+         .view(B, n_heads * head_dim, Len_q)
+     )
+     return output.transpose(1, 2).contiguous()
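A tiny sketch calling the pure-PyTorch kernel directly (not part of the diff); all shapes are illustrative assumptions, chosen small enough to trace by hand:

import torch

from inference_models.models.rfdetr.ms_deform_attn_func import (
    ms_deform_attn_core_pytorch,
)

B, n_heads, head_dim, Len_q, P = 1, 2, 4, 5, 3
shapes = [(8, 8), (4, 4)]                # two feature levels
Len_in = sum(h * w for h, w in shapes)   # 80 flattened positions

value = torch.randn(B, n_heads, head_dim, Len_in)
sampling_locations = torch.rand(B, Len_q, n_heads, len(shapes), P, 2)
attention_weights = torch.softmax(
    torch.randn(B, Len_q, n_heads, len(shapes) * P), dim=-1
)

out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([1, 5, 8]) -> (B, Len_q, n_heads * head_dim)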
inference_models/models/rfdetr/position_encoding.py
@@ -0,0 +1,166 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Copied from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ """
+ Various positional encodings for the transformer.
+ """
+ import math
+
+ import torch
+ from torch import nn
+
+ from inference_models.models.rfdetr.misc import NestedTensor
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+
+     def __init__(
+         self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+     ):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+         self._export = False
+
+     def export(self):
+         self._export = True
+         self._forward_origin = self.forward
+         self.forward = self.forward_export
+
+     def forward(self, tensor_list: NestedTensor, align_dim_orders=True):
+         x = tensor_list.tensors
+         mask = tensor_list.mask
+         assert mask is not None
+         not_mask = ~mask
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack(
+             (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos_y = torch.stack(
+             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         if align_dim_orders:
+             pos = torch.cat((pos_y, pos_x), dim=3).permute(1, 2, 0, 3)
+             # return: (H, W, bs, C)
+         else:
+             pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+             # return: (bs, C, H, W)
+         return pos
+
+     def forward_export(self, mask: torch.Tensor, align_dim_orders=True):
+         assert mask is not None
+         not_mask = ~mask
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(
+             self.num_pos_feats, dtype=torch.float32, device=mask.device
+         )
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack(
+             (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos_y = torch.stack(
+             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         if align_dim_orders:
+             pos = torch.cat((pos_y, pos_x), dim=3).permute(1, 2, 0, 3)
+             # return: (H, W, bs, C)
+         else:
+             pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+             # return: (bs, C, H, W)
+         return pos
+
+
+ class PositionEmbeddingLearned(nn.Module):
+     """
+     Absolute pos embedding, learned.
+     """
+
+     def __init__(self, num_pos_feats=256):
+         super().__init__()
+         self.row_embed = nn.Embedding(50, num_pos_feats)
+         self.col_embed = nn.Embedding(50, num_pos_feats)
+         self.reset_parameters()
+         self._export = False
+
+     def export(self):
+         raise NotImplementedError
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.row_embed.weight)
+         nn.init.uniform_(self.col_embed.weight)
+
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         h, w = x.shape[:2]
+         i = torch.arange(w, device=x.device)
+         j = torch.arange(h, device=x.device)
+         x_emb = self.col_embed(i)
+         y_emb = self.row_embed(j)
+         pos = (
+             torch.cat(
+                 [
+                     x_emb.unsqueeze(0).repeat(h, 1, 1),
+                     y_emb.unsqueeze(1).repeat(1, w, 1),
+                 ],
+                 dim=-1,
+             )
+             .unsqueeze(2)
+             .repeat(1, 1, x.shape[2], 1)
+         )
+         # return: (H, W, bs, C)
+         return pos
+
+
+ def build_position_encoding(hidden_dim, position_embedding):
+     N_steps = hidden_dim // 2
+     if position_embedding in ("v2", "sine"):
+         # TODO find a better way of exposing other arguments
+         position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+     elif position_embedding in ("v3", "learned"):
+         position_embedding = PositionEmbeddingLearned(N_steps)
+     else:
+         raise ValueError(f"not supported {position_embedding}")
+
+     return position_embedding
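A minimal usage sketch for the sine variant (not part of the diff); the feature-map size and the all-False mask are illustrative assumptions:

import torch

from inference_models.models.rfdetr.misc import NestedTensor
from inference_models.models.rfdetr.position_encoding import build_position_encoding

pos_enc = build_position_encoding(hidden_dim=256, position_embedding="sine")

feats = torch.randn(2, 256, 20, 20)              # backbone feature map
mask = torch.zeros(2, 20, 20, dtype=torch.bool)  # no padded pixels
pos = pos_enc(NestedTensor(feats, mask))         # (H, W, bs, C) by default
print(pos.shape)  # torch.Size([20, 20, 2, 256])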
inference_models/models/rfdetr/post_processor.py
@@ -0,0 +1,83 @@
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ def box_cxcywh_to_xyxy(x):
+     x_c, y_c, w, h = x.unbind(-1)
+     b = [
+         (x_c - 0.5 * w.clamp(min=0.0)),
+         (y_c - 0.5 * h.clamp(min=0.0)),
+         (x_c + 0.5 * w.clamp(min=0.0)),
+         (y_c + 0.5 * h.clamp(min=0.0)),
+     ]
+     return torch.stack(b, dim=-1)
+
+
+ class PostProcess(nn.Module):
+     """This module converts the model's output into the format expected by the COCO API"""
+
+     def __init__(self, num_select=300) -> None:
+         super().__init__()
+         self.num_select = num_select
+
+     @torch.no_grad()
+     def forward(self, outputs, target_sizes):
+         """Perform the computation
+         Parameters:
+             outputs: raw outputs of the model
+             target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
+                 For evaluation, this must be the original image size (before any data augmentation)
+                 For visualization, this should be the image size after data augmentation, but before padding
+         """
+         out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
+         out_masks = outputs.get("pred_masks", None)
+
+         assert len(out_logits) == len(target_sizes)
+         assert target_sizes.shape[1] == 2
+
+         prob = out_logits.sigmoid()
+         topk_values, topk_indexes = torch.topk(
+             prob.view(out_logits.shape[0], -1), self.num_select, dim=1
+         )
+         scores = topk_values
+         topk_boxes = topk_indexes // out_logits.shape[2]
+         labels = topk_indexes % out_logits.shape[2]
+         boxes = box_cxcywh_to_xyxy(out_bbox)
+         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+         # and from relative [0, 1] to absolute [0, height] coordinates
+         img_h, img_w = target_sizes.unbind(1)
+         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+         boxes = boxes * scale_fct[:, None, :]
+
+         results = []
+         if out_masks is not None:
+             for i in range(out_masks.shape[0]):
+                 res_i = {"scores": scores[i], "labels": labels[i], "boxes": boxes[i]}
+                 k_idx = topk_boxes[i]
+                 masks_i = torch.gather(
+                     out_masks[i],
+                     0,
+                     k_idx.unsqueeze(-1)
+                     .unsqueeze(-1)
+                     .repeat(1, out_masks.shape[-2], out_masks.shape[-1]),
+                 )  # [K, Hm, Wm]
+                 h, w = target_sizes[i].tolist()
+                 masks_i = F.interpolate(
+                     masks_i.unsqueeze(1),
+                     size=(int(h), int(w)),
+                     mode="bilinear",
+                     align_corners=False,
+                 )  # [K, 1, H, W]
+                 res_i["masks"] = masks_i > 0.0
+                 results.append(res_i)
+         else:
+             results = [
+                 {"scores": s, "labels": l, "boxes": b}
+                 for s, l, b in zip(scores, labels, boxes)
+             ]
+
+         return results
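A usage sketch for PostProcess (not part of the diff); the query count, class count, and image sizes are illustrative assumptions standing in for real model outputs:

import torch

from inference_models.models.rfdetr.post_processor import PostProcess

postprocess = PostProcess(num_select=300)

num_queries, num_classes = 300, 90
outputs = {
    "pred_logits": torch.randn(2, num_queries, num_classes),
    "pred_boxes": torch.rand(2, num_queries, 4),  # normalized cx, cy, w, h
}
target_sizes = torch.tensor([[480, 640], [720, 1280]])  # (height, width) per image

results = postprocess(outputs, target_sizes)
print(results[0]["boxes"].shape)  # torch.Size([300, 4]) in absolute xyxy coordinates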