inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/transformer.py (new file)
@@ -0,0 +1,767 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
7
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
10
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
11
+ # ------------------------------------------------------------------------
12
+ # Modified from DETR (https://github.com/facebookresearch/detr)
13
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
14
+ # ------------------------------------------------------------------------
15
+ """
16
+ Transformer class
17
+ """
18
+ import copy
19
+ import math
20
+ from typing import Optional
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import Tensor, nn
25
+
26
+ from inference_models.models.rfdetr.ms_deform_attn import MSDeformAttn
27
+
28
+
29
+ class MLP(nn.Module):
30
+ """Very simple multi-layer perceptron (also called FFN)"""
31
+
32
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
33
+ super().__init__()
34
+ self.num_layers = num_layers
35
+ h = [hidden_dim] * (num_layers - 1)
36
+ self.layers = nn.ModuleList(
37
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
38
+ )
39
+
40
+ def forward(self, x):
41
+ for i, layer in enumerate(self.layers):
42
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
43
+ return x
44
+
45
+
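A minimal usage sketch of the MLP block above (illustrative only, not part of the packaged file; the import path is inferred from the file listing). With three layers, ReLU is applied after every linear except the last:

    import torch
    from inference_models.models.rfdetr.transformer import MLP

    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    out = mlp(torch.randn(2, 300, 256))   # ReLU after layers 1 and 2, plain output from layer 3
    assert out.shape == (2, 300, 4)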
46
+ def gen_sineembed_for_position(pos_tensor, dim=128):
47
+ # n_query, bs, _ = pos_tensor.size()
48
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
49
+ scale = 2 * math.pi
50
+ dim_t = torch.arange(dim, dtype=pos_tensor.dtype, device=pos_tensor.device)
51
+ dim_t = 10000 ** (2 * (dim_t // 2) / dim)
52
+ x_embed = pos_tensor[:, :, 0] * scale
53
+ y_embed = pos_tensor[:, :, 1] * scale
54
+ pos_x = x_embed[:, :, None] / dim_t
55
+ pos_y = y_embed[:, :, None] / dim_t
56
+ pos_x = torch.stack(
57
+ (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
58
+ ).flatten(2)
59
+ pos_y = torch.stack(
60
+ (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
61
+ ).flatten(2)
62
+ if pos_tensor.size(-1) == 2:
63
+ pos = torch.cat((pos_y, pos_x), dim=2)
64
+ elif pos_tensor.size(-1) == 4:
65
+ w_embed = pos_tensor[:, :, 2] * scale
66
+ pos_w = w_embed[:, :, None] / dim_t
67
+ pos_w = torch.stack(
68
+ (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
69
+ ).flatten(2)
70
+
71
+ h_embed = pos_tensor[:, :, 3] * scale
72
+ pos_h = h_embed[:, :, None] / dim_t
73
+ pos_h = torch.stack(
74
+ (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
75
+ ).flatten(2)
76
+
77
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
78
+ else:
79
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
80
+ return pos
81
+
82
+
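A shape sketch for gen_sineembed_for_position (illustrative, plain tensors, with the function above in scope): a (bs, nq, 4) box tensor in cx/cy/w/h order yields sine/cosine features of size 4 * dim, concatenated as y, x, w, h blocks; the decoder further down calls it with dim = d_model / 2:

    import torch

    boxes = torch.rand(2, 300, 4)                        # normalized cx, cy, w, h
    emb = gen_sineembed_for_position(boxes, dim=128)     # -> (2, 300, 512)
    assert emb.shape == (2, 300, 4 * 128)

    points = torch.rand(2, 300, 2)                       # cx, cy only
    assert gen_sineembed_for_position(points, dim=128).shape == (2, 300, 256)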
83
+ def gen_encoder_output_proposals(
84
+ memory, memory_padding_mask, spatial_shapes, unsigmoid=True
85
+ ):
86
+ """
87
+ Input:
88
+ - memory: bs, \sum{hw}, d_model
89
+ - memory_padding_mask: bs, \sum{hw}
90
+ - spatial_shapes: nlevel, 2
91
+ Output:
92
+ - output_memory: bs, \sum{hw}, d_model
93
+ - output_proposals: bs, \sum{hw}, 4
94
+ """
95
+ N_, S_, C_ = memory.shape
96
+ base_scale = 4.0
97
+ proposals = []
98
+ _cur = 0
99
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
100
+ if memory_padding_mask is not None:
101
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
102
+ N_, H_, W_, 1
103
+ )
104
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
105
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
106
+ else:
107
+ valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
108
+ valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device)
109
+
110
+ grid_y, grid_x = torch.meshgrid(
111
+ torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
112
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
113
+ )
114
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
115
+
116
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
117
+ N_, 1, 1, 2
118
+ )
119
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
120
+
121
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
122
+
123
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
124
+ proposals.append(proposal)
125
+ _cur += H_ * W_
126
+
127
+ output_proposals = torch.cat(proposals, 1)
128
+ output_proposals_valid = (
129
+ (output_proposals > 0.01) & (output_proposals < 0.99)
130
+ ).all(-1, keepdim=True)
131
+
132
+ if unsigmoid:
133
+ output_proposals = torch.log(
134
+ output_proposals / (1 - output_proposals)
135
+ ) # unsigmoid
136
+ if memory_padding_mask is not None:
137
+ output_proposals = output_proposals.masked_fill(
138
+ memory_padding_mask.unsqueeze(-1), float("inf")
139
+ )
140
+ output_proposals = output_proposals.masked_fill(
141
+ ~output_proposals_valid, float("inf")
142
+ )
143
+ else:
144
+ if memory_padding_mask is not None:
145
+ output_proposals = output_proposals.masked_fill(
146
+ memory_padding_mask.unsqueeze(-1), float(0)
147
+ )
148
+ output_proposals = output_proposals.masked_fill(
149
+ ~output_proposals_valid, float(0)
150
+ )
151
+
152
+ output_memory = memory
153
+ if memory_padding_mask is not None:
154
+ output_memory = output_memory.masked_fill(
155
+ memory_padding_mask.unsqueeze(-1), float(0)
156
+ )
157
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
158
+
159
+ return output_memory.to(memory.dtype), output_proposals.to(memory.dtype)
160
+
161
+
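An illustrative call to gen_encoder_output_proposals above with dummy tensors (two feature levels, no padding mask; assumes the function above is in scope). Every location of every level becomes an anchor-like proposal whose width and height grow with the level (0.05 * 2**lvl); with unsigmoid=True the proposals come back in inverse-sigmoid space and out-of-range entries are filled with inf:

    import torch

    spatial_shapes = torch.tensor([[80, 80], [40, 40]])
    memory = torch.randn(2, int(spatial_shapes.prod(1).sum()), 256)   # (2, 8000, 256)
    out_mem, proposals = gen_encoder_output_proposals(
        memory, None, spatial_shapes, unsigmoid=True
    )
    # out_mem: (2, 8000, 256), proposals: (2, 8000, 4) in inverse-sigmoid cx, cy, w, h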
162
+ class Transformer(nn.Module):
163
+
164
+ def __init__(
165
+ self,
166
+ d_model=512,
167
+ sa_nhead=8,
168
+ ca_nhead=8,
169
+ num_queries=300,
170
+ num_decoder_layers=6,
171
+ dim_feedforward=2048,
172
+ dropout=0.0,
173
+ activation="relu",
174
+ normalize_before=False,
175
+ return_intermediate_dec=False,
176
+ group_detr=1,
177
+ two_stage=False,
178
+ num_feature_levels=4,
179
+ dec_n_points=4,
180
+ lite_refpoint_refine=False,
181
+ decoder_norm_type="LN",
182
+ bbox_reparam=False,
183
+ ):
184
+ super().__init__()
185
+ self.encoder = None
186
+
187
+ decoder_layer = TransformerDecoderLayer(
188
+ d_model,
189
+ sa_nhead,
190
+ ca_nhead,
191
+ dim_feedforward,
192
+ dropout,
193
+ activation,
194
+ normalize_before,
195
+ group_detr=group_detr,
196
+ num_feature_levels=num_feature_levels,
197
+ dec_n_points=dec_n_points,
198
+ skip_self_attn=False,
199
+ )
200
+ assert decoder_norm_type in ["LN", "Identity"]
201
+ norm = {
202
+ "LN": lambda channels: nn.LayerNorm(channels),
203
+ "Identity": lambda channels: nn.Identity(),
204
+ }
205
+ decoder_norm = norm[decoder_norm_type](d_model)
206
+
207
+ self.decoder = TransformerDecoder(
208
+ decoder_layer,
209
+ num_decoder_layers,
210
+ decoder_norm,
211
+ return_intermediate=return_intermediate_dec,
212
+ d_model=d_model,
213
+ lite_refpoint_refine=lite_refpoint_refine,
214
+ bbox_reparam=bbox_reparam,
215
+ )
216
+
217
+ self.two_stage = two_stage
218
+ if two_stage:
219
+ self.enc_output = nn.ModuleList(
220
+ [nn.Linear(d_model, d_model) for _ in range(group_detr)]
221
+ )
222
+ self.enc_output_norm = nn.ModuleList(
223
+ [nn.LayerNorm(d_model) for _ in range(group_detr)]
224
+ )
225
+
226
+ self._reset_parameters()
227
+
228
+ self.num_queries = num_queries
229
+ self.d_model = d_model
230
+ self.dec_layers = num_decoder_layers
231
+ self.group_detr = group_detr
232
+ self.num_feature_levels = num_feature_levels
233
+ self.bbox_reparam = bbox_reparam
234
+
235
+ self._export = False
236
+
237
+ def export(self):
238
+ self._export = True
239
+
240
+ def _reset_parameters(self):
241
+ for p in self.parameters():
242
+ if p.dim() > 1:
243
+ nn.init.xavier_uniform_(p)
244
+ for m in self.modules():
245
+ if isinstance(m, MSDeformAttn):
246
+ m._reset_parameters()
247
+
248
+ def get_valid_ratio(self, mask):
249
+ _, H, W = mask.shape
250
+ valid_H = torch.sum(~mask[:, :, 0], 1)
251
+ valid_W = torch.sum(~mask[:, 0, :], 1)
252
+ valid_ratio_h = valid_H.float() / H
253
+ valid_ratio_w = valid_W.float() / W
254
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
255
+ return valid_ratio
256
+
257
+ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
258
+ src_flatten = []
259
+ mask_flatten = [] if masks is not None else None
260
+ lvl_pos_embed_flatten = []
261
+ spatial_shapes = []
262
+ valid_ratios = [] if masks is not None else None
263
+ for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)):
264
+ bs, c, h, w = src.shape
265
+ spatial_shape = (h, w)
266
+ spatial_shapes.append(spatial_shape)
267
+
268
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
269
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
270
+ lvl_pos_embed_flatten.append(pos_embed)
271
+ src_flatten.append(src)
272
+ if masks is not None:
273
+ mask = masks[lvl].flatten(1) # bs, hw
274
+ mask_flatten.append(mask)
275
+ memory = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
276
+ if masks is not None:
277
+ mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
278
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
279
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
280
+ spatial_shapes = torch.as_tensor(
281
+ spatial_shapes, dtype=torch.long, device=memory.device
282
+ )
283
+ level_start_index = torch.cat(
284
+ (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
285
+ )
286
+
287
+ if self.two_stage:
288
+ output_memory, output_proposals = gen_encoder_output_proposals(
289
+ memory, mask_flatten, spatial_shapes, unsigmoid=not self.bbox_reparam
290
+ )
291
+ # group detr for first stage
292
+ refpoint_embed_ts, memory_ts, boxes_ts = [], [], []
293
+ group_detr = self.group_detr if self.training else 1
294
+ for g_idx in range(group_detr):
295
+ output_memory_gidx = self.enc_output_norm[g_idx](
296
+ self.enc_output[g_idx](output_memory)
297
+ )
298
+
299
+ enc_outputs_class_unselected_gidx = self.enc_out_class_embed[g_idx](
300
+ output_memory_gidx
301
+ )
302
+ if self.bbox_reparam:
303
+ enc_outputs_coord_delta_gidx = self.enc_out_bbox_embed[g_idx](
304
+ output_memory_gidx
305
+ )
306
+ enc_outputs_coord_cxcy_gidx = (
307
+ enc_outputs_coord_delta_gidx[..., :2]
308
+ * output_proposals[..., 2:]
309
+ + output_proposals[..., :2]
310
+ )
311
+ enc_outputs_coord_wh_gidx = (
312
+ enc_outputs_coord_delta_gidx[..., 2:].exp()
313
+ * output_proposals[..., 2:]
314
+ )
315
+ enc_outputs_coord_unselected_gidx = torch.concat(
316
+ [enc_outputs_coord_cxcy_gidx, enc_outputs_coord_wh_gidx], dim=-1
317
+ )
318
+ else:
319
+ enc_outputs_coord_unselected_gidx = (
320
+ self.enc_out_bbox_embed[g_idx](output_memory_gidx)
321
+ + output_proposals
322
+ ) # (bs, \sum{hw}, 4) unsigmoid
323
+
324
+ topk = min(
325
+ self.num_queries, enc_outputs_class_unselected_gidx.shape[-2]
326
+ )
327
+ topk_proposals_gidx = torch.topk(
328
+ enc_outputs_class_unselected_gidx.max(-1)[0], topk, dim=1
329
+ )[
330
+ 1
331
+ ] # bs, nq
332
+
333
+ refpoint_embed_gidx_undetach = torch.gather(
334
+ enc_outputs_coord_unselected_gidx,
335
+ 1,
336
+ topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, 4),
337
+ ) # unsigmoid
338
+ # for decoder layer, detached as initial ones, (bs, nq, 4)
339
+ refpoint_embed_gidx = refpoint_embed_gidx_undetach.detach()
340
+
341
+ # get memory tgt
342
+ tgt_undetach_gidx = torch.gather(
343
+ output_memory_gidx,
344
+ 1,
345
+ topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, self.d_model),
346
+ )
347
+
348
+ refpoint_embed_ts.append(refpoint_embed_gidx)
349
+ memory_ts.append(tgt_undetach_gidx)
350
+ boxes_ts.append(refpoint_embed_gidx_undetach)
351
+ # concat on dim=1, the nq dimension, (bs, nq, d) --> (bs, nq, d)
352
+ refpoint_embed_ts = torch.cat(refpoint_embed_ts, dim=1)
353
+ # (bs, nq, d)
354
+ memory_ts = torch.cat(memory_ts, dim=1) # .transpose(0, 1)
355
+ boxes_ts = torch.cat(boxes_ts, dim=1) # .transpose(0, 1)
356
+
357
+ if self.dec_layers > 0:
358
+ tgt = query_feat.unsqueeze(0).repeat(bs, 1, 1)
359
+ refpoint_embed = refpoint_embed.unsqueeze(0).repeat(bs, 1, 1)
360
+ if self.two_stage:
361
+ ts_len = refpoint_embed_ts.shape[-2]
362
+ refpoint_embed_ts_subset = refpoint_embed[..., :ts_len, :]
363
+ refpoint_embed_subset = refpoint_embed[..., ts_len:, :]
364
+
365
+ if self.bbox_reparam:
366
+ refpoint_embed_cxcy = (
367
+ refpoint_embed_ts_subset[..., :2] * refpoint_embed_ts[..., 2:]
368
+ )
369
+ refpoint_embed_cxcy = (
370
+ refpoint_embed_cxcy + refpoint_embed_ts[..., :2]
371
+ )
372
+ refpoint_embed_wh = (
373
+ refpoint_embed_ts_subset[..., 2:].exp()
374
+ * refpoint_embed_ts[..., 2:]
375
+ )
376
+ refpoint_embed_ts_subset = torch.concat(
377
+ [refpoint_embed_cxcy, refpoint_embed_wh], dim=-1
378
+ )
379
+ else:
380
+ refpoint_embed_ts_subset = (
381
+ refpoint_embed_ts_subset + refpoint_embed_ts
382
+ )
383
+
384
+ refpoint_embed = torch.concat(
385
+ [refpoint_embed_ts_subset, refpoint_embed_subset], dim=-2
386
+ )
387
+
388
+ hs, references = self.decoder(
389
+ tgt,
390
+ memory,
391
+ memory_key_padding_mask=mask_flatten,
392
+ pos=lvl_pos_embed_flatten,
393
+ refpoints_unsigmoid=refpoint_embed,
394
+ level_start_index=level_start_index,
395
+ spatial_shapes=spatial_shapes,
396
+ valid_ratios=(
397
+ valid_ratios.to(memory.dtype)
398
+ if valid_ratios is not None
399
+ else valid_ratios
400
+ ),
401
+ )
402
+ else:
403
+ assert self.two_stage, "if not using decoder, two_stage must be True"
404
+ hs = None
405
+ references = None
406
+
407
+ if self.two_stage:
408
+ if self.bbox_reparam:
409
+ return hs, references, memory_ts, boxes_ts
410
+ else:
411
+ return hs, references, memory_ts, boxes_ts.sigmoid()
412
+ return hs, references, None, None
413
+
414
+
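Two plain-tensor sketches of what Transformer.forward() above does before the decoder runs (illustrative, hypothetical sizes). First, the multi-scale features are flattened per level and concatenated; second, when two_stage is enabled, the top num_queries proposals are chosen by their best class score and gathered as initial reference boxes. The enc_out_class_embed and enc_out_bbox_embed heads used in that branch are not defined in __init__ and are presumably attached to this module by the surrounding RF-DETR model:

    import torch

    # per-level flattening, as in the first loop of forward()
    srcs = [torch.randn(2, 256, 80, 80), torch.randn(2, 256, 40, 40)]
    memory = torch.cat([s.flatten(2).transpose(1, 2) for s in srcs], dim=1)   # (2, 8000, 256)
    shapes = torch.as_tensor([s.shape[-2:] for s in srcs], dtype=torch.long)
    level_start_index = torch.cat(
        (shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])
    )                                                                          # tensor([0, 6400])

    # two-stage selection of reference boxes, as in the group loop of forward()
    scores = torch.randn(2, 8000, 90)    # stand-in for enc_outputs_class_unselected_gidx
    coords = torch.rand(2, 8000, 4)      # stand-in for enc_outputs_coord_unselected_gidx
    topk = 300
    idx = torch.topk(scores.max(-1)[0], topk, dim=1)[1]                     # (2, 300)
    refpoints = torch.gather(coords, 1, idx.unsqueeze(-1).repeat(1, 1, 4))  # (2, 300, 4)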
415
+ class TransformerDecoder(nn.Module):
416
+
417
+ def __init__(
418
+ self,
419
+ decoder_layer,
420
+ num_layers,
421
+ norm=None,
422
+ return_intermediate=False,
423
+ d_model=256,
424
+ lite_refpoint_refine=False,
425
+ bbox_reparam=False,
426
+ ):
427
+ super().__init__()
428
+ self.layers = _get_clones(decoder_layer, num_layers)
429
+ self.num_layers = num_layers
430
+ self.d_model = d_model
431
+ self.norm = norm
432
+ self.return_intermediate = return_intermediate
433
+ self.lite_refpoint_refine = lite_refpoint_refine
434
+ self.bbox_reparam = bbox_reparam
435
+
436
+ self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)
437
+
438
+ self._export = False
439
+
440
+ def export(self):
441
+ self._export = True
442
+
443
+ def refpoints_refine(self, refpoints_unsigmoid, new_refpoints_delta):
444
+ if self.bbox_reparam:
445
+ new_refpoints_cxcy = (
446
+ new_refpoints_delta[..., :2] * refpoints_unsigmoid[..., 2:]
447
+ + refpoints_unsigmoid[..., :2]
448
+ )
449
+ new_refpoints_wh = (
450
+ new_refpoints_delta[..., 2:].exp() * refpoints_unsigmoid[..., 2:]
451
+ )
452
+ new_refpoints_unsigmoid = torch.concat(
453
+ [new_refpoints_cxcy, new_refpoints_wh], dim=-1
454
+ )
455
+ else:
456
+ new_refpoints_unsigmoid = refpoints_unsigmoid + new_refpoints_delta
457
+ return new_refpoints_unsigmoid
458
+
459
+ def forward(
460
+ self,
461
+ tgt,
462
+ memory,
463
+ tgt_mask: Optional[Tensor] = None,
464
+ memory_mask: Optional[Tensor] = None,
465
+ tgt_key_padding_mask: Optional[Tensor] = None,
466
+ memory_key_padding_mask: Optional[Tensor] = None,
467
+ pos: Optional[Tensor] = None,
468
+ refpoints_unsigmoid: Optional[Tensor] = None,
469
+ # for memory
470
+ level_start_index: Optional[Tensor] = None, # num_levels
471
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
472
+ valid_ratios: Optional[Tensor] = None,
473
+ ):
474
+ output = tgt
475
+
476
+ intermediate = []
477
+ hs_refpoints_unsigmoid = [refpoints_unsigmoid]
478
+
479
+ def get_reference(refpoints):
480
+ # [num_queries, batch_size, 4]
481
+ obj_center = refpoints[..., :4]
482
+
483
+ if self._export:
484
+ query_sine_embed = gen_sineembed_for_position(
485
+ obj_center, self.d_model / 2
486
+ ) # bs, nq, 256*2
487
+ refpoints_input = obj_center[:, :, None] # bs, nq, 1, 4
488
+ else:
489
+ refpoints_input = (
490
+ obj_center[:, :, None]
491
+ * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
492
+ ) # bs, nq, nlevel, 4
493
+ query_sine_embed = gen_sineembed_for_position(
494
+ refpoints_input[:, :, 0, :], self.d_model / 2
495
+ ) # bs, nq, 256*2
496
+ query_pos = self.ref_point_head(query_sine_embed)
497
+ return obj_center, refpoints_input, query_pos, query_sine_embed
498
+
499
+ # always use init refpoints
500
+ if self.lite_refpoint_refine:
501
+ if self.bbox_reparam:
502
+ obj_center, refpoints_input, query_pos, query_sine_embed = (
503
+ get_reference(refpoints_unsigmoid)
504
+ )
505
+ else:
506
+ obj_center, refpoints_input, query_pos, query_sine_embed = (
507
+ get_reference(refpoints_unsigmoid.sigmoid())
508
+ )
509
+
510
+ for layer_id, layer in enumerate(self.layers):
511
+ # iter refine each layer
512
+ if not self.lite_refpoint_refine:
513
+ if self.bbox_reparam:
514
+ obj_center, refpoints_input, query_pos, query_sine_embed = (
515
+ get_reference(refpoints_unsigmoid)
516
+ )
517
+ else:
518
+ obj_center, refpoints_input, query_pos, query_sine_embed = (
519
+ get_reference(refpoints_unsigmoid.sigmoid())
520
+ )
521
+
522
+ # For the first decoder layer, we do not apply transformation over p_s
523
+ pos_transformation = 1
524
+
525
+ query_pos = query_pos * pos_transformation
526
+
527
+ output = layer(
528
+ output,
529
+ memory,
530
+ tgt_mask=tgt_mask,
531
+ memory_mask=memory_mask,
532
+ tgt_key_padding_mask=tgt_key_padding_mask,
533
+ memory_key_padding_mask=memory_key_padding_mask,
534
+ pos=pos,
535
+ query_pos=query_pos,
536
+ query_sine_embed=query_sine_embed,
537
+ is_first=(layer_id == 0),
538
+ reference_points=refpoints_input,
539
+ spatial_shapes=spatial_shapes,
540
+ level_start_index=level_start_index,
541
+ )
542
+
543
+ if not self.lite_refpoint_refine:
544
+ # box iterative update
545
+ new_refpoints_delta = self.bbox_embed(output)
546
+ new_refpoints_unsigmoid = self.refpoints_refine(
547
+ refpoints_unsigmoid, new_refpoints_delta
548
+ )
549
+ if layer_id != self.num_layers - 1:
550
+ hs_refpoints_unsigmoid.append(new_refpoints_unsigmoid)
551
+ refpoints_unsigmoid = new_refpoints_unsigmoid.detach()
552
+
553
+ if self.return_intermediate:
554
+ intermediate.append(self.norm(output))
555
+
556
+ if self.norm is not None:
557
+ output = self.norm(output)
558
+ if self.return_intermediate:
559
+ intermediate.pop()
560
+ intermediate.append(output)
561
+
562
+ if self.return_intermediate:
563
+ if self._export:
564
+ # to shape: B, N, C
565
+ hs = intermediate[-1]
566
+ if self.bbox_embed is not None:
567
+ ref = hs_refpoints_unsigmoid[-1]
568
+ else:
569
+ ref = refpoints_unsigmoid
570
+ return hs, ref
571
+ # box iterative update
572
+ if self.bbox_embed is not None:
573
+ return [
574
+ torch.stack(intermediate),
575
+ torch.stack(hs_refpoints_unsigmoid),
576
+ ]
577
+ else:
578
+ return [torch.stack(intermediate), refpoints_unsigmoid.unsqueeze(0)]
579
+
580
+ return output.unsqueeze(0)
581
+
582
+
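A numeric sketch of refpoints_refine() above with bbox_reparam=True (plain tensors, illustrative): the center shift is scaled by the current width/height and the width/height are multiplied by exp(delta). Note that bbox_embed, which produces the per-layer deltas when lite_refpoint_refine is False, is not created in __init__; it is presumably assigned on the decoder by the model that owns it:

    import torch

    ref = torch.tensor([[0.50, 0.50, 0.20, 0.10]])        # current cx, cy, w, h
    delta = torch.tensor([[0.10, -0.20, 0.00, 0.6931]])   # predicted offsets
    cxcy = delta[..., :2] * ref[..., 2:] + ref[..., :2]   # -> [[0.52, 0.48]]
    wh = delta[..., 2:].exp() * ref[..., 2:]              # -> [[0.20, 0.20]], since exp(0.6931) ~= 2
    refined = torch.cat([cxcy, wh], dim=-1)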
583
+ class TransformerDecoderLayer(nn.Module):
584
+
585
+ def __init__(
586
+ self,
587
+ d_model,
588
+ sa_nhead,
589
+ ca_nhead,
590
+ dim_feedforward=2048,
591
+ dropout=0.1,
592
+ activation="relu",
593
+ normalize_before=False,
594
+ group_detr=1,
595
+ num_feature_levels=4,
596
+ dec_n_points=4,
597
+ skip_self_attn=False,
598
+ ):
599
+ super().__init__()
600
+ # Decoder Self-Attention
601
+ self.self_attn = nn.MultiheadAttention(
602
+ embed_dim=d_model, num_heads=sa_nhead, dropout=dropout, batch_first=True
603
+ )
604
+ self.dropout1 = nn.Dropout(dropout)
605
+ self.norm1 = nn.LayerNorm(d_model)
606
+
607
+ # Decoder Cross-Attention
608
+ self.cross_attn = MSDeformAttn(
609
+ d_model,
610
+ n_levels=num_feature_levels,
611
+ n_heads=ca_nhead,
612
+ n_points=dec_n_points,
613
+ )
614
+
615
+ self.nhead = ca_nhead
616
+
617
+ # Implementation of Feedforward model
618
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
619
+ self.dropout = nn.Dropout(dropout)
620
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
621
+
622
+ self.norm2 = nn.LayerNorm(d_model)
623
+ self.norm3 = nn.LayerNorm(d_model)
624
+
625
+ self.dropout2 = nn.Dropout(dropout)
626
+ self.dropout3 = nn.Dropout(dropout)
627
+
628
+ self.activation = _get_activation_fn(activation)
629
+ self.normalize_before = normalize_before
630
+ self.group_detr = group_detr
631
+
632
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
633
+ return tensor if pos is None else tensor + pos
634
+
635
+ def forward_post(
636
+ self,
637
+ tgt,
638
+ memory,
639
+ tgt_mask: Optional[Tensor] = None,
640
+ memory_mask: Optional[Tensor] = None,
641
+ tgt_key_padding_mask: Optional[Tensor] = None,
642
+ memory_key_padding_mask: Optional[Tensor] = None,
643
+ pos: Optional[Tensor] = None,
644
+ query_pos: Optional[Tensor] = None,
645
+ query_sine_embed=None,
646
+ is_first=False,
647
+ reference_points=None,
648
+ spatial_shapes=None,
649
+ level_start_index=None,
650
+ ):
651
+ bs, num_queries, _ = tgt.shape
652
+
653
+ # ========== Begin of Self-Attention =============
654
+ # Apply projections here
655
+ # shape: batch_size x num_queries x 256
656
+ q = k = tgt + query_pos
657
+ v = tgt
658
+ if self.training:
659
+ q = torch.cat(q.split(num_queries // self.group_detr, dim=1), dim=0)
660
+ k = torch.cat(k.split(num_queries // self.group_detr, dim=1), dim=0)
661
+ v = torch.cat(v.split(num_queries // self.group_detr, dim=1), dim=0)
662
+
663
+ tgt2 = self.self_attn(
664
+ q,
665
+ k,
666
+ v,
667
+ attn_mask=tgt_mask,
668
+ key_padding_mask=tgt_key_padding_mask,
669
+ need_weights=False,
670
+ )[0]
671
+
672
+ if self.training:
673
+ tgt2 = torch.cat(tgt2.split(bs, dim=0), dim=1)
674
+ # ========== End of Self-Attention =============
675
+
676
+ tgt = tgt + self.dropout1(tgt2)
677
+ tgt = self.norm1(tgt)
678
+
679
+ # ========== Begin of Cross-Attention =============
680
+ tgt2 = self.cross_attn(
681
+ self.with_pos_embed(tgt, query_pos),
682
+ reference_points,
683
+ memory,
684
+ spatial_shapes,
685
+ level_start_index,
686
+ memory_key_padding_mask,
687
+ )
688
+ # ========== End of Cross-Attention =============
689
+
690
+ tgt = tgt + self.dropout2(tgt2)
691
+ tgt = self.norm2(tgt)
692
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
693
+ tgt = tgt + self.dropout3(tgt2)
694
+ tgt = self.norm3(tgt)
695
+ return tgt
696
+
697
+ def forward(
698
+ self,
699
+ tgt,
700
+ memory,
701
+ tgt_mask: Optional[Tensor] = None,
702
+ memory_mask: Optional[Tensor] = None,
703
+ tgt_key_padding_mask: Optional[Tensor] = None,
704
+ memory_key_padding_mask: Optional[Tensor] = None,
705
+ pos: Optional[Tensor] = None,
706
+ query_pos: Optional[Tensor] = None,
707
+ query_sine_embed=None,
708
+ is_first=False,
709
+ reference_points=None,
710
+ spatial_shapes=None,
711
+ level_start_index=None,
712
+ ):
713
+ return self.forward_post(
714
+ tgt,
715
+ memory,
716
+ tgt_mask,
717
+ memory_mask,
718
+ tgt_key_padding_mask,
719
+ memory_key_padding_mask,
720
+ pos,
721
+ query_pos,
722
+ query_sine_embed,
723
+ is_first,
724
+ reference_points,
725
+ spatial_shapes,
726
+ level_start_index,
727
+ )
728
+
729
+
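A plain-tensor sketch of the group-DETR trick in forward_post() above (illustrative sizes): during training the queries are split into group_detr independent groups along the query axis and stacked on the batch axis, so self-attention never mixes groups; afterwards the result is folded back into the original layout:

    import torch

    bs, nq, d, groups = 2, 600, 256, 4
    q = torch.randn(bs, nq, d)
    q_grouped = torch.cat(q.split(nq // groups, dim=1), dim=0)   # (8, 150, 256)
    # ... self-attention runs on the grouped tensor ...
    q_merged = torch.cat(q_grouped.split(bs, dim=0), dim=1)      # back to (2, 600, 256)
    assert torch.equal(q, q_merged)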
730
+ def _get_clones(module, N):
731
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
732
+
733
+
734
+ def build_transformer(args):
735
+ try:
736
+ two_stage = args.two_stage
737
+ except:
738
+ two_stage = False
739
+
740
+ return Transformer(
741
+ d_model=args.hidden_dim,
742
+ sa_nhead=args.sa_nheads,
743
+ ca_nhead=args.ca_nheads,
744
+ num_queries=args.num_queries,
745
+ dropout=args.dropout,
746
+ dim_feedforward=args.dim_feedforward,
747
+ num_decoder_layers=args.dec_layers,
748
+ return_intermediate_dec=True,
749
+ group_detr=args.group_detr,
750
+ two_stage=two_stage,
751
+ num_feature_levels=args.num_feature_levels,
752
+ dec_n_points=args.dec_n_points,
753
+ lite_refpoint_refine=args.lite_refpoint_refine,
754
+ decoder_norm_type=args.decoder_norm,
755
+ bbox_reparam=args.bbox_reparam,
756
+ )
757
+
758
+
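An illustrative way to drive build_transformer() above with a plain namespace; the field values here are hypothetical, not taken from the package, and the helper falls back to two_stage=False when args has no two_stage attribute. Instantiating the module assumes MSDeformAttn from this package constructs on the current platform:

    from types import SimpleNamespace

    args = SimpleNamespace(
        hidden_dim=256, sa_nheads=8, ca_nheads=16, num_queries=300,
        dropout=0.0, dim_feedforward=2048, dec_layers=3, group_detr=4,
        two_stage=True, num_feature_levels=4, dec_n_points=4,
        lite_refpoint_refine=True, decoder_norm="LN", bbox_reparam=True,
    )
    transformer = build_transformer(args)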
759
+ def _get_activation_fn(activation):
760
+ """Return an activation function given a string"""
761
+ if activation == "relu":
762
+ return F.relu
763
+ if activation == "gelu":
764
+ return F.gelu
765
+ if activation == "glu":
766
+ return F.glu
767
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
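A last illustrative note on _get_activation_fn() above (assumes the function is in scope): it resolves a string to a functional activation, and "glu" is accepted as well even though the error message only mentions relu/gelu:

    import torch
    import torch.nn.functional as F

    assert _get_activation_fn("relu") is F.relu
    assert _get_activation_fn("gelu") is F.gelu
    y = _get_activation_fn("glu")(torch.randn(4, 8))   # GLU halves the last dimension -> (4, 4)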