birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/detection/rt_detr_v2.py
@@ -0,0 +1,1147 @@
+"""
+RT-DETR v2 (Real-Time DEtection TRansformer), adapted from
+https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch
+
+Paper "RT-DETRv2: Improved Baseline with Bag-of-Freebies for Real-Time Detection Transformer",
+https://arxiv.org/abs/2407.17140
+"""
+
+# Reference license: Apache-2.0
+
+import copy
+import math
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import MLP
+from torchvision.ops import boxes as box_ops
+
+from birder.common import training_utils
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+from birder.net.detection.base import DetectionBaseNet
+from birder.net.detection.deformable_detr import HungarianMatcher
+from birder.net.detection.deformable_detr import inverse_sigmoid
+from birder.net.detection.rt_detr_v1 import HybridEncoder
+from birder.net.detection.rt_detr_v1 import get_contrastive_denoising_training_group
+from birder.net.detection.rt_detr_v1 import varifocal_loss
+from birder.ops.msda import MultiScaleDeformableAttention as MSDA
+
+
+def _get_clones(module: nn.Module, N: int) -> nn.ModuleList:
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+class MultiheadAttention(nn.Module):
+    def __init__(self, d_model: int, num_heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
+        super().__init__()
+        assert d_model % num_heads == 0, "d_model should be divisible by num_heads"
+
+        self.num_heads = num_heads
+        self.head_dim = d_model // num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(d_model, d_model)
+        self.k_proj = nn.Linear(d_model, d_model)
+        self.v_proj = nn.Linear(d_model, d_model)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(d_model, d_model)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        nn.init.xavier_uniform_(self.q_proj.weight)
+        nn.init.xavier_uniform_(self.k_proj.weight)
+        nn.init.xavier_uniform_(self.v_proj.weight)
+        nn.init.xavier_uniform_(self.proj.weight)
+        if self.q_proj.bias is not None:
+            nn.init.zeros_(self.q_proj.bias)
+            nn.init.zeros_(self.k_proj.bias)
+            nn.init.zeros_(self.v_proj.bias)
+            nn.init.zeros_(self.proj.bias)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        B, l_q, C = query.shape
+        q = self.q_proj(query).reshape(B, l_q, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.k_proj(key).reshape(B, key.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.v_proj(value).reshape(B, value.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+
+        if attn_mask is not None:
+            # attn_mask is (L, S) boolean where True = masked
+            # SDPA expects True = attend, so we invert
+            mask = ~attn_mask
+        else:
+            mask = None
+
+        attn = F.scaled_dot_product_attention(  # pylint: disable=not-callable
+            q, k, v, attn_mask=mask, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
+        )
+
+        attn = attn.transpose(1, 2).reshape(B, l_q, C)
+        x = self.proj(attn)
+        x = self.proj_drop(x)
+
+        return x
+
+
+class MultiScaleDeformableAttention(nn.Module):
+    """
+    Multi-Scale Deformable Attention with per-level point counts
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        n_levels: int,
+        n_heads: int,
+        n_points: list[int],
+        method: Literal["default", "discrete"] = "default",
+        offset_scale: float = 0.5,
+    ) -> None:
+        super().__init__()
+        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
+        assert len(n_points) == n_levels, f"n_points list length must equal n_levels ({n_levels})"
+        assert method in ("default", "discrete"), "method must be 'default' or 'discrete'"
+
+        dim_per_head = d_model // n_heads
+        if ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0) is False:
+            raise ValueError(
+                "Set d_model in MultiScaleDeformableAttention to make the dimension of each attention head a power of 2"
+            )
+
+        self.im2col_step = 64
+        self.d_model = d_model
+        self.n_levels = n_levels
+        self.n_heads = n_heads
+        self.method = method
+        self.offset_scale = offset_scale
+
+        self.num_points = n_points
+        num_points_scale = [1.0 / n for n in self.num_points for _ in range(n)]
+        self.num_points_scale = nn.Buffer(torch.tensor(num_points_scale, dtype=torch.float32))
+        self.total_points = sum(self.num_points)
+        self.uniform_points = len(set(self.num_points)) == 1
+
+        self.msda = MSDA()
+
+        self.sampling_offsets = nn.Linear(d_model, n_heads * self.total_points * 2)
+        self.attention_weights = nn.Linear(d_model, n_heads * self.total_points)
+        self.value_proj = nn.Linear(d_model, d_model)
+        self.output_proj = nn.Linear(d_model, d_model)
+
+        self.reset_parameters()
+
+        if method == "discrete":
+            for param in self.sampling_offsets.parameters():
+                param.requires_grad_(False)
+
+    def reset_parameters(self) -> None:
+        nn.init.constant_(self.sampling_offsets.weight, 0.0)
+        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]
+        grid_init = grid_init.view(self.n_heads, 1, 2).repeat(1, self.total_points, 1)
+        scaling = torch.concat([torch.arange(1, n + 1, dtype=torch.float32) for n in self.num_points]).view(1, -1, 1)
+        grid_init = grid_init * scaling
+
+        with torch.no_grad():
+            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+
+        nn.init.constant_(self.attention_weights.weight, 0.0)
+        nn.init.constant_(self.attention_weights.bias, 0.0)
+        nn.init.xavier_uniform_(self.value_proj.weight)
+        nn.init.constant_(self.value_proj.bias, 0.0)
+        nn.init.xavier_uniform_(self.output_proj.weight)
+        nn.init.constant_(self.output_proj.bias, 0.0)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        reference_points: torch.Tensor,
+        input_flatten: torch.Tensor,
+        input_spatial_shapes: torch.Tensor,
+        input_level_start_index: torch.Tensor,
+        input_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        N, num_queries, _ = query.size()
+        N, sequence_length, _ = input_flatten.size()
+        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length
+
+        value = self.value_proj(input_flatten)
+        if input_padding_mask is not None:
+            value = value.masked_fill(input_padding_mask[..., None], float(0))
+
+        value = value.view(N, sequence_length, self.n_heads, self.d_model // self.n_heads)
+
+        sampling_offsets = self.sampling_offsets(query).view(N, num_queries, self.n_heads, self.total_points, 2)
+        attention_weights = self.attention_weights(query).view(N, num_queries, self.n_heads, self.total_points)
+        attention_weights = F.softmax(attention_weights, dim=-1)
+
+        if reference_points.shape[2] != self.n_levels:
+            if reference_points.shape[2] == 1:
+                reference_points = reference_points.expand(-1, -1, self.n_levels, -1)
+            else:
+                raise ValueError(
+                    f"reference_points must have {self.n_levels} levels, but got {reference_points.shape[2]}"
+                )
+
+        if reference_points.shape[-1] == 2:
+            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+            sampling_locations_list = []
+            offset_idx = 0
+            for lvl in range(self.n_levels):
+                n_pts = self.num_points[lvl]
+                ref = reference_points[:, :, None, lvl : lvl + 1, :].expand(-1, -1, self.n_heads, n_pts, -1)
+                off = sampling_offsets[:, :, :, offset_idx : offset_idx + n_pts, :]
+                norm = offset_normalizer[lvl : lvl + 1].view(1, 1, 1, 1, 2)
+                sampling_locations_list.append(ref + off / norm)
+                offset_idx += n_pts
+
+            sampling_locations = torch.concat(sampling_locations_list, dim=3)
+
+        elif reference_points.shape[-1] == 4:
+            sampling_locations_list = []
+            offset_idx = 0
+            num_points_scale = self.num_points_scale.to(dtype=query.dtype)
+            for lvl in range(self.n_levels):
+                n_pts = self.num_points[lvl]
+                ref = reference_points[:, :, None, lvl : lvl + 1, :].expand(-1, -1, self.n_heads, n_pts, -1)
+                off = sampling_offsets[:, :, :, offset_idx : offset_idx + n_pts, :]
+                scale = num_points_scale[offset_idx : offset_idx + n_pts].view(1, 1, 1, n_pts, 1)
+                sampling_locations_list.append(ref[..., :2] + off * scale * ref[..., 2:] * self.offset_scale)
+                offset_idx += n_pts
+
+            sampling_locations = torch.concat(sampling_locations_list, dim=3)
+
+        else:
+            raise ValueError(
+                f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead"
+            )
+
+        if self.method == "discrete":
+            output = self._forward_fallback(
+                value, input_spatial_shapes, sampling_locations, attention_weights, method="discrete"
+            )
+        else:
+            if self.uniform_points is True:
+                n_pts = self.num_points[0]
+                sampling_locations = sampling_locations.view(N, num_queries, self.n_heads, self.n_levels, n_pts, 2)
+                attention_weights = attention_weights.view(N, num_queries, self.n_heads, self.n_levels, n_pts)
+                output = self.msda(
+                    value,
+                    input_spatial_shapes,
+                    input_level_start_index,
+                    sampling_locations,
+                    attention_weights,
+                    self.im2col_step,
+                )
+            else:
+                output = self._forward_fallback(
+                    value, input_spatial_shapes, sampling_locations, attention_weights, method="default"
+                )

+        output = self.output_proj(output)
+        return output
+
+    def _forward_fallback(
+        self,
+        value: torch.Tensor,
+        spatial_shapes: torch.Tensor,
+        sampling_locations: torch.Tensor,
+        attention_weights: torch.Tensor,
+        method: str = "default",
+    ) -> torch.Tensor:
+        B, _, n_heads, head_dim = value.size()
+        num_queries = sampling_locations.size(1)
+
+        sampling_grids = 2 * sampling_locations - 1
+        split_shape: list[int] = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).tolist()
+        value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1)
+        sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
+        sampling_locations_list = sampling_grids.split(self.num_points, dim=-2)
+
+        sampling_value_list = []
+        spatial_shapes_list: list[list[int]] = spatial_shapes.tolist()
+        for level, (H, W) in enumerate(spatial_shapes_list):
+            value_l = value_list[level].reshape(B * n_heads, head_dim, H, W)
+            sampling_grid_l = sampling_locations_list[level]
+
+            if method == "default":
+                sampling_value_l = F.grid_sample(
+                    value_l,
+                    sampling_grid_l,
+                    mode="bilinear",
+                    padding_mode="zeros",
+                    align_corners=False,
+                )
+            else:
+                sampling_grid_l = sampling_grid_l.clone()
+                sampling_grid_l[..., 0] += 1.0 / W
+                sampling_grid_l[..., 1] += 1.0 / H
+                sampling_value_l = F.grid_sample(
+                    value_l,
+                    sampling_grid_l,
+                    mode="nearest",
+                    padding_mode="border",
+                    align_corners=False,
+                )
+
+                # Original upstream code (expected grid of [0, 1])
+                # e.g. without the 'sampling_grids = 2 * sampling_locations - 1'
+                #
+                # n_pts = self.num_points[level]
+                # sampling_coord = (sampling_grid_l * torch.tensor([[W, H]], device=value.device) + 0.5).to(torch.int64)
+                # sampling_coord[..., 0] = sampling_coord[..., 0].clamp(0, W - 1)
+                # sampling_coord[..., 1] = sampling_coord[..., 1].clamp(0, H - 1)
+                # sampling_coord = sampling_coord.reshape(B * n_heads, num_queries * n_pts, 2)
+                # s_idx = (
+                #     torch.arange(sampling_coord.shape[0], device=value.device)
+                #     .unsqueeze(-1)
+                #     .repeat(1, sampling_coord.shape[1])
+                # )
+                # sampling_value_l = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
+                # ... = sampling_value_l.permute(0, 2, 1).reshape(B * n_heads, head_dim, num_queries, n_pts)
+
+            sampling_value_list.append(sampling_value_l)
+
+        attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(B * n_heads, 1, num_queries, sum(self.num_points))
+        output = torch.concat(sampling_value_list, dim=-1) * attn_weights
+        output = output.sum(-1).reshape(B, n_heads * head_dim, num_queries)
+
+        return output.permute(0, 2, 1)
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        d_ffn: int,
+        dropout: float,
+        n_levels: int,
+        n_heads: int,
+        n_points: list[int],
+        method: Literal["default", "discrete"] = "default",
+        offset_scale: float = 0.5,
+    ) -> None:
+        super().__init__()
+
+        # Self attention
+        self.self_attn = MultiheadAttention(d_model, n_heads, attn_drop=dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+
+        # Cross attention
+        self.cross_attn = MultiScaleDeformableAttention(
+            d_model, n_levels, n_heads, n_points, method=method, offset_scale=offset_scale
+        )
+        self.norm2 = nn.LayerNorm(d_model)
+
+        # FFN
+        self.linear1 = nn.Linear(d_model, d_ffn)
+        self.linear2 = nn.Linear(d_ffn, d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+
+        self.activation = nn.ReLU()
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        query_pos: torch.Tensor,
+        reference_points: torch.Tensor,
+        src: torch.Tensor,
+        src_spatial_shapes: torch.Tensor,
+        level_start_index: torch.Tensor,
+        src_padding_mask: Optional[torch.Tensor],
+        self_attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Self attention
+        q = tgt + query_pos
+        k = tgt + query_pos
+
+        tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm1(tgt)
+
+        # Cross attention
+        tgt2 = self.cross_attn(
+            tgt + query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask
+        )
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm2(tgt)
+
+        # FFN
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm3(tgt)
+
+        return tgt
+
+
+# pylint: disable=invalid-name
+class RT_DETRDecoder(nn.Module):
+    """
+    RT-DETR v2 Decoder with top-k query selection
+    """
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_classes: int,
+        num_queries: int,
+        num_decoder_layers: int,
+        num_levels: int,
+        num_heads: int,
+        dim_feedforward: int,
+        dropout: float,
+        num_decoder_points: list[int],
+        method: Literal["default", "discrete"] = "default",
+        offset_scale: float = 0.5,
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.num_queries = num_queries
+        self.num_levels = num_levels
+
+        self.enc_output = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+        )
+        self.enc_score_head = nn.Linear(hidden_dim, num_classes)
+        self.enc_bbox_head = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
+
+        decoder_layer = TransformerDecoderLayer(
+            hidden_dim,
+            dim_feedforward,
+            dropout,
+            num_levels,
+            num_heads,
+            num_decoder_points,
+            method=method,
+            offset_scale=offset_scale,
+        )
+        self.layers = _get_clones(decoder_layer, num_decoder_layers)
+
+        self.query_pos_head = MLP(4, [2 * hidden_dim, hidden_dim], activation_layer=nn.ReLU)
+        self.class_embed = nn.ModuleList([nn.Linear(hidden_dim, num_classes) for _ in range(num_decoder_layers)])
+        self.bbox_embed = nn.ModuleList(
+            [MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU) for _ in range(num_decoder_layers)]
+        )
+        self.use_cache = True
+        self._anchor_cache: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+
+        # Weights initialization
+        prior_prob = 0.01
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        nn.init.xavier_uniform_(self.enc_output[0].weight)
+        nn.init.xavier_uniform_(self.enc_score_head.weight)
+        nn.init.constant_(self.enc_score_head.bias, bias_value)
+        nn.init.zeros_(self.enc_bbox_head[-2].weight)
+        nn.init.zeros_(self.enc_bbox_head[-2].bias)
+        for class_embed in self.class_embed:
+            nn.init.constant_(class_embed.bias, bias_value)
+
+        for bbox_embed in self.bbox_embed:
+            nn.init.zeros_(bbox_embed[-2].weight)
+            nn.init.zeros_(bbox_embed[-2].bias)
+
+    def set_cache_enabled(self, enabled: bool) -> None:
+        self.use_cache = enabled
+        if enabled is False:
+            self.clear_cache()
+
+    def clear_cache(self) -> None:
+        self._anchor_cache.clear()
+
+    def _generate_anchors(
+        self,
+        spatial_shapes: list[list[int]],
+        grid_size: float = 0.05,
+        device: torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        cache_key: Optional[str] = None
+        use_cache = self.use_cache is True and torch.jit.is_tracing() is False and torch.jit.is_scripting() is False
+        if use_cache is True:
+            spatial_key = ",".join(f"{int(h)}x{int(w)}" for h, w in spatial_shapes)
+            cache_key = f"{spatial_key}_{grid_size}_{device}_{dtype}"
+            cached = self._anchor_cache.get(cache_key)
+            if cached is not None:
+                return cached
+
+        anchors = []
+        for lvl, (h, w) in enumerate(spatial_shapes):
+            grid_y, grid_x = torch.meshgrid(
+                torch.arange(h, dtype=dtype, device=device),
+                torch.arange(w, dtype=dtype, device=device),
+                indexing="ij",
+            )
+            grid_xy = torch.stack([grid_x, grid_y], dim=-1)
+            valid_wh = torch.tensor([w, h], dtype=dtype, device=device)
+            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_wh
+            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
+            anchors.append(torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
+
+        anchors = torch.concat(anchors, dim=1)
+        eps = 0.01
+        valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(dim=-1, keepdim=True)
+        anchors = torch.log(anchors / (1 - anchors))
+        anchors = torch.where(valid_mask, anchors, torch.inf)
+
+        if cache_key is not None:
+            self._anchor_cache[cache_key] = (anchors, valid_mask)
+
+        return (anchors, valid_mask)
+
+    def _get_decoder_input(
+        self,
+        memory: torch.Tensor,
+        spatial_shapes: list[list[int]],
+        memory_padding_mask: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device, dtype=memory.dtype)
+        if memory_padding_mask is not None:
+            valid_mask = valid_mask & ~memory_padding_mask.unsqueeze(-1)
+
+        memory = valid_mask.to(memory.dtype) * memory
+        output_memory = self.enc_output(memory)
+        enc_outputs_class = self.enc_score_head(output_memory)
+        if memory_padding_mask is not None:
+            enc_outputs_class = enc_outputs_class.masked_fill(memory_padding_mask[..., None], float("-inf"))
+
+        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
+
+        # Select top-k queries based on classification confidence
+        _, topk_ind = torch.topk(enc_outputs_class.max(dim=-1).values, self.num_queries, dim=1)
+
+        # Gather reference points
+        reference_points_unact = enc_outputs_coord_unact.gather(
+            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1])
+        )
+
+        enc_topk_bboxes = reference_points_unact.sigmoid()
+
+        # Gather encoder logits for loss computation
+        enc_topk_logits = enc_outputs_class.gather(
+            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
+        )
+
+        # Extract region features
+        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+        target = target.detach()
+
+        return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)
+
+    def forward(  # pylint: disable=too-many-locals
+        self,
+        feats: list[torch.Tensor],
+        spatial_shapes: list[list[int]],
+        level_start_index: list[int],
+        denoising_class: Optional[torch.Tensor] = None,
+        denoising_bbox_unact: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        padding_mask: Optional[list[torch.Tensor]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        memory = []
+        mask_flatten = []
+        for idx, feat in enumerate(feats):
+            feat_flat = feat.flatten(2).permute(0, 2, 1)  # (B, H*W, C)
+            memory.append(feat_flat)
+            if padding_mask is not None:
+                mask_flatten.append(padding_mask[idx].flatten(1))
+
+        memory = torch.concat(memory, dim=1)
+        memory_padding_mask = torch.concat(mask_flatten, dim=1) if mask_flatten else None
+
+        # Get decoder input (query selection)
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
+            memory, spatial_shapes, memory_padding_mask
+        )
+
+        # Concatenate denoising queries if provided
+        if denoising_class is not None and denoising_bbox_unact is not None:
+            target = torch.concat([denoising_class, target], dim=1)
+            init_ref_points_unact = torch.concat([denoising_bbox_unact, init_ref_points_unact], dim=1)
+
+        # Prepare spatial shapes and level start index as tensors
+        spatial_shapes_tensor = torch.tensor(spatial_shapes, dtype=torch.long, device=memory.device)
+        level_start_index_tensor = torch.tensor(level_start_index, dtype=torch.long, device=memory.device)
+
+        # Decoder forward
+        out_bboxes = []
+        out_logits = []
+        reference_points = init_ref_points_unact.sigmoid()
+        for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
+            query_pos = self.query_pos_head(reference_points)
+            reference_points_input = reference_points.unsqueeze(2).repeat(1, 1, len(spatial_shapes), 1)
+            target = decoder_layer(
+                target,
+                query_pos,
+                reference_points_input,
+                memory,
+                spatial_shapes_tensor,
+                level_start_index_tensor,
+                memory_padding_mask,
+                attn_mask,
+            )
+
+            bbox_delta = bbox_head(target)
+            new_reference_points = inverse_sigmoid(reference_points) + bbox_delta
+            new_reference_points = new_reference_points.sigmoid()
+
+            # Classification
+            class_logits = class_head(target)
+
+            out_bboxes.append(new_reference_points)
+            out_logits.append(class_logits)
+
+            # Update reference points for next layer
+            reference_points = new_reference_points.detach()
+
+        out_bboxes = torch.stack(out_bboxes)
+        out_logits = torch.stack(out_logits)
+
+        return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits)
+
+
+# pylint: disable=invalid-name
+class RT_DETR_v2(DetectionBaseNet):
+    default_size = (640, 640)
+
+    def __init__(
+        self,
+        num_classes: int,
+        backbone: DetectorBackbone,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+        export_mode: bool = False,
+    ) -> None:
+        super().__init__(num_classes, backbone, config=config, size=size, export_mode=export_mode)
+        assert self.config is not None, "must set config"
+
+        self.reparameterized = False
+
+        # Sigmoid based classification (no background class in predictions)
+        self.num_classes = self.num_classes - 1
+
+        hidden_dim = self.config.get("hidden_dim", 256)
+        num_heads = self.config.get("num_heads", 8)
+        dim_feedforward = self.config.get("dim_feedforward", 1024)
+        dropout: float = self.config.get("dropout", 0.0)
+        num_encoder_layers: int = self.config.get("num_encoder_layers", 1)
+        num_decoder_layers: int = self.config["num_decoder_layers"]
+        num_queries: int = self.config.get("num_queries", 300)
+        expansion: float = self.config.get("expansion", 1.0)
+        depth_multiplier: float = self.config.get("depth_multiplier", 1.0)
+        use_giou: bool = self.config.get("use_giou", True)
+        num_denoising: int = self.config.get("num_denoising", 100)
+        label_noise_ratio: float = self.config.get("label_noise_ratio", 0.5)
+        box_noise_scale: float = self.config.get("box_noise_scale", 1.0)
+        num_decoder_points: list[int] = self.config.get("num_decoder_points", [4, 4, 4])
+        method: Literal["default", "discrete"] = self.config.get("method", "default")
+        offset_scale: float = self.config.get("offset_scale", 0.5)
+
+        self.hidden_dim = hidden_dim
+        self.num_queries = num_queries
+        self.num_denoising = num_denoising
+        self.label_noise_ratio = label_noise_ratio
+        self.box_noise_scale = box_noise_scale
+
+        self.backbone.return_channels = self.backbone.return_channels[-3:]
+        self.backbone.return_stages = self.backbone.return_stages[-3:]
+        self.num_levels = len(self.backbone.return_channels)
+
+        self.encoder = HybridEncoder(
+            in_channels=self.backbone.return_channels,
+            hidden_dim=hidden_dim,
+            num_encoder_layers=num_encoder_layers,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            num_heads=num_heads,
+            expansion=expansion,
+            depth_multiplier=depth_multiplier,
+        )
+        self.decoder = RT_DETRDecoder(
+            hidden_dim=hidden_dim,
+            num_classes=self.num_classes,
+            num_queries=num_queries,
+            num_decoder_layers=num_decoder_layers,
+            num_levels=self.num_levels,
+            num_heads=num_heads,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            num_decoder_points=num_decoder_points,
+            method=method,
+            offset_scale=offset_scale,
+        )
+
+        self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0, use_giou=use_giou)
+
+        # Denoising class embedding for Contrastive denoising (CDN) training
+        if self.num_denoising > 0:
+            self.denoising_class_embed = nn.Embedding(self.num_classes + 1, hidden_dim, padding_idx=self.num_classes)
+
+        if self.export_mode is False:
+            self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]
+
+    def _set_cache_enabled(self, enabled: bool) -> None:
+        self.encoder.set_cache_enabled(enabled)
+        self.decoder.set_cache_enabled(enabled)
+
+    def clear_cache(self) -> None:
+        self.encoder.clear_cache()
+        self.decoder.clear_cache()
+
+    def adjust_size(self, new_size: tuple[int, int]) -> None:
+        if new_size == self.size:
+            return
+
+        super().adjust_size(new_size)
+        self.clear_cache()
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        super().set_dynamic_size(dynamic_size)
+        self._set_cache_enabled(dynamic_size is False)
+
+    def reset_classifier(self, num_classes: int) -> None:
+        self.num_classes = num_classes
+
+        self.decoder.enc_score_head = nn.Linear(self.hidden_dim, num_classes)
+        self.decoder.class_embed = nn.ModuleList(
+            [nn.Linear(self.hidden_dim, num_classes) for _ in range(len(self.decoder.layers))]
+        )
+
+        if self.num_denoising > 0:
+            self.denoising_class_embed = nn.Embedding(num_classes + 1, self.hidden_dim, padding_idx=num_classes)
+
+        prior_prob = 0.01
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        nn.init.constant_(self.decoder.enc_score_head.bias, bias_value)
+        for class_embed in self.decoder.class_embed:
+            nn.init.constant_(class_embed.bias, bias_value)
+
+    def freeze(self, freeze_classifier: bool = True) -> None:
+        for param in self.parameters():
+            param.requires_grad_(False)
+
+        if freeze_classifier is False:
+            for param in self.decoder.class_embed.parameters():
+                param.requires_grad_(True)
+            for param in self.decoder.enc_score_head.parameters():
+                param.requires_grad_(True)
+            if self.num_denoising > 0:
+                for param in self.denoising_class_embed.parameters():
+                    param.requires_grad_(True)
+
+    def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.concat([src for (src, _) in indices])
+        return (batch_idx, src_idx)
+
+    def _class_loss(
+        self,
+        cls_logits: torch.Tensor,
+        box_output: torch.Tensor,
+        targets: list[dict[str, torch.Tensor]],
+        indices: list[torch.Tensor],
+        num_boxes: float,
+    ) -> torch.Tensor:
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
+        target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
+        target_classes[idx] = target_classes_o
+
+        target_classes_onehot = torch.zeros(
+            [cls_logits.shape[0], cls_logits.shape[1], cls_logits.shape[2] + 1],
+            dtype=cls_logits.dtype,
+            layout=cls_logits.layout,
+            device=cls_logits.device,
+        )
+        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+        target_classes_onehot = target_classes_onehot[:, :, :-1]
+
+        src_boxes = box_output[idx]
+        target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+        ious = torch.diag(
+            box_ops.box_iou(
+                box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
+                box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
+            )
+        ).detach()
+
+        target_score_o = torch.zeros(cls_logits.shape[:2], dtype=cls_logits.dtype, device=cls_logits.device)
+        target_score_o[idx] = ious.to(cls_logits.dtype)
+        target_score = target_score_o.unsqueeze(-1) * target_classes_onehot
+
+        loss = varifocal_loss(cls_logits, target_score, target_classes_onehot, alpha=0.75, gamma=2.0)
+        loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.shape[1]
+
+        return loss_ce
+
+    def _box_loss(
+        self,
+        box_output: torch.Tensor,
+        targets: list[dict[str, torch.Tensor]],
+        indices: list[torch.Tensor],
+        num_boxes: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = box_output[idx]
+        target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
+        loss_bbox = loss_bbox.sum() / num_boxes
+
+        loss_giou = 1 - torch.diag(
+            box_ops.generalized_box_iou(
+                box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
+                box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
+            )
+        )
+        loss_giou = loss_giou.sum() / num_boxes
+
+        return (loss_bbox, loss_giou)
+
+    def _compute_denoising_loss(
+        self,
+        dn_out_bboxes: torch.Tensor,
+        dn_out_logits: torch.Tensor,
+        targets: list[dict[str, torch.Tensor]],
+        dn_meta: dict[str, Any],
+        num_boxes: float,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        dn_positive_idx = dn_meta["dn_positive_idx"]
+        num_groups = dn_meta["dn_num_group"]
+
+        loss_ce_list = []
+        loss_bbox_list = []
+        loss_giou_list = []
+
+        dn_num_boxes = max(num_boxes * num_groups, 1.0)
+        for layer_idx in range(dn_out_logits.shape[0]):
+            # Construct indices from denoising metadata
+            indices = []
+            for batch_idx, pos_idx in enumerate(dn_positive_idx):
+                if len(pos_idx) > 0:
+                    src_idx = pos_idx
+                    num_gt = len(targets[batch_idx]["labels"])
+                    tgt_idx = torch.arange(num_gt, device=pos_idx.device).repeat(num_groups)
+                    indices.append((src_idx, tgt_idx))
+                else:
+                    indices.append(
+                        (
+                            torch.tensor([], dtype=torch.long, device=dn_out_logits.device),
+                            torch.tensor([], dtype=torch.long, device=dn_out_logits.device),
+                        )
+                    )
+
+            loss_ce = self._class_loss(
+                dn_out_logits[layer_idx], dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes
+            )
+            loss_bbox, loss_giou = self._box_loss(dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes)
+
+            loss_ce_list.append(loss_ce)
+            loss_bbox_list.append(loss_bbox)
+            loss_giou_list.append(loss_giou)
+
+        loss_ce_dn = torch.stack(loss_ce_list).sum()
+        loss_bbox_dn = torch.stack(loss_bbox_list).sum()
+        loss_giou_dn = torch.stack(loss_giou_list).sum()
+
+        return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)
+
+    @torch.jit.unused  # type: ignore[untyped-decorator]
+    @torch.compiler.disable()  # type: ignore[untyped-decorator]
+    def _compute_loss_from_outputs(  # pylint: disable=too-many-locals
+        self,
+        targets: list[dict[str, torch.Tensor]],
+        out_bboxes: torch.Tensor,
+        out_logits: torch.Tensor,
+        enc_topk_bboxes: torch.Tensor,
+        enc_topk_logits: torch.Tensor,
+        dn_out_bboxes: Optional[torch.Tensor] = None,
+        dn_out_logits: Optional[torch.Tensor] = None,
+        dn_meta: Optional[dict[str, Any]] = None,
+    ) -> dict[str, torch.Tensor]:
+        # Compute the average number of target boxes across all nodes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=out_logits.device)
+        if training_utils.is_dist_available_and_initialized() is True:
+            torch.distributed.all_reduce(num_boxes)
+
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+
+        loss_ce_list = []
+        loss_bbox_list = []
+        loss_giou_list = []
+
+        # Decoder losses (all layers)
+        for layer_idx in range(out_logits.shape[0]):
+            indices = self.matcher(out_logits[layer_idx], out_bboxes[layer_idx], targets)
+            loss_ce = self._class_loss(out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes)
+            loss_bbox, loss_giou = self._box_loss(out_bboxes[layer_idx], targets, indices, num_boxes)
+            loss_ce_list.append(loss_ce)
+            loss_bbox_list.append(loss_bbox)
+            loss_giou_list.append(loss_giou)
+
+        # Encoder auxiliary loss
+        enc_indices = self.matcher(enc_topk_logits, enc_topk_bboxes, targets)
+        loss_ce_enc = self._class_loss(enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes)
+        loss_bbox_enc, loss_giou_enc = self._box_loss(enc_topk_bboxes, targets, enc_indices, num_boxes)
+        loss_ce_list.append(loss_ce_enc)
+        loss_bbox_list.append(loss_bbox_enc)
+        loss_giou_list.append(loss_giou_enc)
+
+        loss_ce = torch.stack(loss_ce_list).sum()  # VFL weight is 1
+        loss_bbox = torch.stack(loss_bbox_list).sum() * 5
+        loss_giou = torch.stack(loss_giou_list).sum() * 2
+
+        # Add denoising loss if available
+        if dn_out_bboxes is not None and dn_out_logits is not None and dn_meta is not None:
+            loss_ce_dn, loss_bbox_dn, loss_giou_dn = self._compute_denoising_loss(
+                dn_out_bboxes, dn_out_logits, targets, dn_meta, num_boxes
+            )
+            loss_ce = loss_ce + loss_ce_dn
+            loss_bbox = loss_bbox + loss_bbox_dn * 5
+            loss_giou = loss_giou + loss_giou_dn * 2
+
+        losses = {
+            "labels": loss_ce,
+            "boxes": loss_bbox,
+            "giou": loss_giou,
+        }
+
+        return losses
+
+    @torch.jit.unused  # type: ignore[untyped-decorator]
+    @torch.compiler.disable()  # type: ignore[untyped-decorator]
+    def compute_loss(
+        self,
+        encoder_features: list[torch.Tensor],
+        spatial_shapes: list[list[int]],
+        level_start_index: list[int],
+        targets: list[dict[str, torch.Tensor]],
+        images: Any,
+        masks: Optional[list[torch.Tensor]] = None,
+    ) -> dict[str, torch.Tensor]:
+        device = encoder_features[0].device
+        for idx, target in enumerate(targets):
+            boxes = target["boxes"]
+            boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
+            boxes = boxes / torch.tensor(images.image_sizes[idx][::-1] * 2, dtype=torch.float32, device=device)
+            targets[idx]["boxes"] = boxes
+            targets[idx]["labels"] = target["labels"] - 1  # No background
+
+        denoising_class, denoising_bbox_unact, attn_mask, dn_meta = self._prepare_cdn_queries(targets)
+
+        out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits = self.decoder(
+            encoder_features,
+            spatial_shapes,
+            level_start_index,
+            denoising_class,
+            denoising_bbox_unact,
+            attn_mask,
+            masks,
+        )
+
+        if dn_meta is not None:
+            dn_num_split, _num_queries = dn_meta["dn_num_split"]
+            dn_out_bboxes = out_bboxes[:, :, :dn_num_split]
+            dn_out_logits = out_logits[:, :, :dn_num_split]
+            out_bboxes = out_bboxes[:, :, dn_num_split:]
+            out_logits = out_logits[:, :, dn_num_split:]
+        else:
+            dn_out_bboxes = None
+            dn_out_logits = None
+
+        losses: dict[str, torch.Tensor] = self._compute_loss_from_outputs(
+            targets, out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_out_bboxes, dn_out_logits, dn_meta
+        )
+
+        return losses
+
+    def postprocess_detections(
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
+    ) -> list[dict[str, torch.Tensor]]:
+        prob = class_logits.sigmoid()
+        topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1)
+        scores = topk_values
+        topk_boxes = topk_indexes // class_logits.shape[2]
+        labels = topk_indexes % class_logits.shape[2]
+        labels += 1  # Background offset
+
+        target_sizes = torch.tensor(image_shapes, device=class_logits.device)
+
+        # Convert to [x0, y0, x1, y1] format
+        boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
+        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+        # Convert from relative [0, 1] to absolute [0, height] coordinates
+        img_h, img_w = target_sizes.unbind(1)
+        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+        boxes = boxes * scale_fct[:, None, :]
+
+        detections: list[dict[str, torch.Tensor]] = []
+        for s, l, b in zip(scores, labels, boxes):
+            detections.append(
+                {
+                    "boxes": b,
+                    "scores": s,
+                    "labels": l,
+                }
+            )
+
+        return detections
+
+    @torch.jit.unused  # type: ignore[untyped-decorator]
+    def _prepare_cdn_queries(
+        self, targets: list[dict[str, torch.Tensor]]
+    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[dict[str, Any]]]:
+        if self.num_denoising > 0:
+            result: tuple[
+                Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[dict[str, Any]]
+            ] = get_contrastive_denoising_training_group(
+                targets,
+                self.num_classes,
+                self.num_queries,
+                self.denoising_class_embed,
+                num_denoising_queries=self.num_denoising,
+                label_noise_ratio=self.label_noise_ratio,
+                box_noise_scale=self.box_noise_scale,
+            )
+            return result
+
+        return (None, None, None, None)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        targets: Optional[list[dict[str, torch.Tensor]]] = None,
+        masks: Optional[torch.Tensor] = None,
+        image_sizes: Optional[list[list[int]]] = None,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
+        self._input_check(targets)
+        images = self._to_img_list(x, image_sizes)
+
+        # Backbone features
+        features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
+        feature_list = list(features.values())
+
+        # Hybrid encoder
+        mask_list: list[torch.Tensor] = []
+        for feat in feature_list:
+            if masks is not None:
+                mask_size = feat.shape[-2:]
+                m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
+            else:
+                B, _, H, W = feat.size()
+                m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
+            mask_list.append(m)
+
+        encoder_features = self.encoder(feature_list, masks=mask_list)
+
+        # Prepare spatial shapes and level start index
+        spatial_shapes: list[list[int]] = []
+        level_start_index: list[int] = [0]
+        for feat in encoder_features:
+            H = feat.shape[2]
+            W = feat.shape[3]
+            spatial_shapes.append([H, W])
+            level_start_index.append(H * W + level_start_index[-1])
+
+        level_start_index.pop()
+
+        detections: list[dict[str, torch.Tensor]] = []
+        losses: dict[str, torch.Tensor] = {}
+        if self.training is True:
+            assert targets is not None, "targets should not be None when in training mode"
+            losses = self.compute_loss(encoder_features, spatial_shapes, level_start_index, targets, images, mask_list)
+        else:
+            # Inference path - no CDN
+            out_bboxes, out_logits, _, _ = self.decoder(
+                encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list
+            )
+            detections = self.postprocess_detections(out_logits[-1], out_bboxes[-1], images.image_sizes)
+
+        return (detections, losses)
+
+    @torch.no_grad()  # type: ignore[untyped-decorator]
+    def reparameterize_model(self) -> None:
+        if self.reparameterized is True:
+            return
+
+        for module in self.modules():
+            if hasattr(module, "reparameterize") is True:
+                module.reparameterize()
+
+        self.reparameterized = True
+
+
+registry.register_model_config(
+    "rt_detr_v2_s",
+    RT_DETR_v2,
+    config={
+        "num_decoder_layers": 3,
+        "expansion": 0.5,
+    },
+)
+registry.register_model_config(
+    "rt_detr_v2_s_dsp",
+    RT_DETR_v2,
+    config={
+        "num_decoder_layers": 3,
+        "expansion": 0.5,
+        "method": "discrete",
+    },
+)
+registry.register_model_config(
+    "rt_detr_v2",
+    RT_DETR_v2,
+    config={
+        "num_decoder_layers": 6,
+    },
+)
+registry.register_model_config(
+    "rt_detr_v2_dsp",
+    RT_DETR_v2,
+    config={
+        "num_decoder_layers": 6,
+        "method": "discrete",
+    },
+)
+registry.register_model_config(
+    "rt_detr_v2_l",
+    RT_DETR_v2,
+    config={
+        "num_decoder_layers": 6,
+        "expansion": 1.0,
+        "depth_multiplier": 1.0,
+        "num_heads": 12,  # Deviates from upstream to keep head_dim=32 (power of 2) for MSDA kernel
+        "hidden_dim": 384,
+        "dim_feedforward": 2048,
+    },
+)
+registry.register_model_config(
+    "rt_detr_v2_l_dsp",
+    RT_DETR_v2,
+    config={
+        "num_decoder_layers": 6,
+        "expansion": 1.0,
+        "depth_multiplier": 1.0,
+        "num_heads": 12,  # Deviates from upstream to keep head_dim=32 (power of 2) for MSDA kernel
+        "hidden_dim": 384,
+        "dim_feedforward": 2048,
+        "method": "discrete",
+    },
+)