birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/detection/plain_detr.py (new file)
@@ -0,0 +1,852 @@
+ """
+ Plain DETR, adapted from
+ https://github.com/impiga/Plain-DETR
+
+ Paper "DETR Doesn't Need Multi-Scale or Locality Design", https://arxiv.org/abs/2308.01904
+
+ Changes from original:
+ * Move background index to first from last (to be inline with the rest of Birder detectors)
+ * Removed two stage support
+ * Only support pre-norm (original supports both pre- and post-norm)
+ """
+
+ # Reference license: MIT
+
+ import copy
+ import math
+ from typing import Any
+ from typing import Literal
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torchvision.ops import MLP
+ from torchvision.ops import boxes as box_ops
+ from torchvision.ops import sigmoid_focal_loss
+
+ from birder.common import training_utils
+ from birder.model_registry import registry
+ from birder.net.base import DetectorBackbone
+ from birder.net.detection.base import DetectionBaseNet
+ from birder.net.detection.deformable_detr import HungarianMatcher
+ from birder.net.detection.deformable_detr import inverse_sigmoid
+ from birder.net.detection.detr import PositionEmbeddingSine
+ from birder.ops.soft_nms import SoftNMS
+
+
+ def _get_clones(module: nn.Module, N: int) -> nn.ModuleList:
+     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+ class MultiheadAttention(nn.Module):
+     def __init__(self, d_model: int, num_heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
+         super().__init__()
+         assert d_model % num_heads == 0, "d_model should be divisible by num_heads"
+
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+         self.scale = self.head_dim**-0.5
+
+         self.q_proj = nn.Linear(d_model, d_model)
+         self.k_proj = nn.Linear(d_model, d_model)
+         self.v_proj = nn.Linear(d_model, d_model)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(d_model, d_model)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         nn.init.xavier_uniform_(self.q_proj.weight)
+         nn.init.xavier_uniform_(self.k_proj.weight)
+         nn.init.xavier_uniform_(self.v_proj.weight)
+         nn.init.xavier_uniform_(self.proj.weight)
+         if self.q_proj.bias is not None:
+             nn.init.zeros_(self.q_proj.bias)
+             nn.init.zeros_(self.k_proj.bias)
+             nn.init.zeros_(self.v_proj.bias)
+             nn.init.zeros_(self.proj.bias)
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         B, l_q, C = query.size()
+         q = self.q_proj(query).reshape(B, l_q, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(key).reshape(B, key.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(value).reshape(B, value.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+
+         if key_padding_mask is not None:
+             # key_padding_mask is expected to be boolean (True = masked)
+             # SDPA expects True = attend, so we invert
+             attn_mask = ~key_padding_mask[:, None, None, :]
+         else:
+             attn_mask = None
+
+         attn = F.scaled_dot_product_attention(  # pylint: disable=not-callable
+             q, k, v, attn_mask=attn_mask, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
+         )
+
+         attn = attn.transpose(1, 2).reshape(B, l_q, C)
+         x = self.proj(attn)
+         x = self.proj_drop(x)
+
+         return x
+
+
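The boolean-mask inversion in the forward pass above is easy to get backwards, so here is a minimal standalone sketch (illustration only, not part of the added file) of the convention this custom MultiheadAttention relies on: key_padding_mask marks padded keys with True, while F.scaled_dot_product_attention takes a boolean attn_mask in which True means the position may be attended.

import torch
import torch.nn.functional as F

# Toy shapes: batch=1, heads=1, two queries, three keys; the last key is padding
q = torch.randn(1, 1, 2, 8)
k = torch.randn(1, 1, 3, 8)
v = torch.randn(1, 1, 3, 8)

key_padding_mask = torch.tensor([[False, False, True]])  # True = padded key (ignore)
attn_mask = ~key_padding_mask[:, None, None, :]          # True = may attend (SDPA convention)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 1, 2, 8])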
+ class GlobalCrossAttention(nn.Module):
+     """
+     Global cross-attention with Box-to-Pixel Relative Position Bias (BoxRPB)
+
+     This utilizes Box-to-Pixel Relative Position Bias (BoxRPB) to guide attention
+     using the spatial relationship between query boxes and image features.
+     The bias calculation is decomposed into axial (x and y) components.
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         dropout: float,
+         rpe_hidden_dim: int,
+         feature_stride: int,
+         rpe_type: Literal["linear", "log"],
+     ) -> None:
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scale = self.head_dim**-0.5
+         self.feature_stride = feature_stride
+         self.rpe_type = rpe_type
+
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.out_proj = nn.Linear(embed_dim, embed_dim)
+
+         self.attn_drop = nn.Dropout(dropout)
+         self.proj_drop = nn.Dropout(dropout)
+
+         self.cpb_mlp_x = nn.Sequential(
+             nn.Linear(2, rpe_hidden_dim),
+             nn.ReLU(inplace=True),
+             nn.Linear(rpe_hidden_dim, num_heads, bias=False),
+         )
+         self.cpb_mlp_y = nn.Sequential(
+             nn.Linear(2, rpe_hidden_dim),
+             nn.ReLU(inplace=True),
+             nn.Linear(rpe_hidden_dim, num_heads, bias=False),
+         )
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         reference_points: torch.Tensor,
+         spatial_shape: tuple[int, int],
+         key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         B, num_queries, _ = query.size()
+         H, W = spatial_shape
+
+         q = self.q_proj(query)
+         k = self.k_proj(key)
+         v = self.v_proj(value)
+
+         q = q.view(B, num_queries, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+         k = k.view(B, H * W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+         v = v.view(B, H * W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+         q = q * self.scale
+
+         attn = q @ k.transpose(-2, -1)
+         rpe = self._compute_box_rpe(reference_points, H, W, query.device)
+         attn = attn + rpe
+
+         if key_padding_mask is not None:
+             attn = attn.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
+
+         attn = F.softmax(attn, dim=-1)
+         attn = self.attn_drop(attn)
+
+         out = attn @ v
+         out = out.permute(0, 2, 1, 3).reshape(B, num_queries, self.embed_dim)
+         out = self.out_proj(out)
+         out = self.proj_drop(out)
+
+         return out
+
+     # pylint: disable=too-many-locals
+     def _compute_box_rpe(self, reference_points: torch.Tensor, H: int, W: int, device: torch.device) -> torch.Tensor:
+         B, n_q, _ = reference_points.size()
+         stride = self.feature_stride
+
+         # cxcywh to xyxy
+         cx, cy, bw, bh = reference_points.unbind(-1)
+         x1 = cx - bw / 2
+         y1 = cy - bh / 2
+         x2 = cx + bw / 2
+         y2 = cy + bh / 2
+
+         # Scale to pixel coordinates
+         x1 = x1 * (W * stride)
+         y1 = y1 * (H * stride)
+         x2 = x2 * (W * stride)
+         y2 = y2 * (H * stride)
+
+         # Pixel grid (cell centers)
+         pos_x = torch.linspace(0.5, W - 0.5, W, device=device) * stride
+         pos_y = torch.linspace(0.5, H - 0.5, H, device=device) * stride
+
+         # Box edge to pixel distances
+         delta_x1 = x1[:, :, None] - pos_x[None, None, :]
+         delta_x2 = x2[:, :, None] - pos_x[None, None, :]
+         delta_y1 = y1[:, :, None] - pos_y[None, None, :]
+         delta_y2 = y2[:, :, None] - pos_y[None, None, :]
+
+         if self.rpe_type == "log":
+             delta_x1 = torch.sign(delta_x1) * torch.log2(torch.abs(delta_x1) + 1.0) / 3.0
+             delta_x2 = torch.sign(delta_x2) * torch.log2(torch.abs(delta_x2) + 1.0) / 3.0
+             delta_y1 = torch.sign(delta_y1) * torch.log2(torch.abs(delta_y1) + 1.0) / 3.0
+             delta_y2 = torch.sign(delta_y2) * torch.log2(torch.abs(delta_y2) + 1.0) / 3.0
+
+         delta_x = torch.stack([delta_x1, delta_x2], dim=-1)
+         delta_y = torch.stack([delta_y1, delta_y2], dim=-1)
+
+         rpe_x = self.cpb_mlp_x(delta_x)
+         rpe_y = self.cpb_mlp_y(delta_y)
+
+         # Axial decomposition: rpe[h,w] = rpe_y[h] + rpe_x[w]
+         rpe = rpe_y[:, :, :, None, :] + rpe_x[:, :, None, :, :]
+         rpe = rpe.reshape(B, n_q, H * W, self.num_heads)
+         rpe = rpe.permute(0, 3, 1, 2)
+
+         return rpe
+
+
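For orientation, a small self-contained sketch (illustration only, not code from the wheel) of what _compute_box_rpe produces: per-query axial distances from the box edges to the pixel-cell centers pass through two small MLPs, and the y and x contributions are summed into a per-head bias over all H*W keys.

import torch
from torch import nn

# Toy sizes: the real module uses the detector's hidden_dim, heads and feature stride
B, n_q, num_heads, H, W, stride, hidden = 1, 2, 8, 4, 4, 32, 16

boxes = torch.rand(B, n_q, 4)  # normalized cxcywh reference points
cx, cy, bw, bh = boxes.unbind(-1)
x1, x2 = (cx - bw / 2) * W * stride, (cx + bw / 2) * W * stride
y1, y2 = (cy - bh / 2) * H * stride, (cy + bh / 2) * H * stride

pos_x = torch.linspace(0.5, W - 0.5, W) * stride  # pixel-cell centers (x)
pos_y = torch.linspace(0.5, H - 0.5, H) * stride  # pixel-cell centers (y)

delta_x = torch.stack([x1[..., None] - pos_x, x2[..., None] - pos_x], dim=-1)  # (B, n_q, W, 2)
delta_y = torch.stack([y1[..., None] - pos_y, y2[..., None] - pos_y], dim=-1)  # (B, n_q, H, 2)

mlp_x = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, num_heads, bias=False))
mlp_y = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, num_heads, bias=False))

rpe = mlp_y(delta_y)[:, :, :, None, :] + mlp_x(delta_x)[:, :, None, :, :]  # (B, n_q, H, W, heads)
rpe = rpe.reshape(B, n_q, H * W, num_heads).permute(0, 3, 1, 2)            # (B, heads, n_q, H*W)
print(rpe.shape)  # torch.Size([1, 8, 2, 16])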
+ class GlobalDecoderLayer(nn.Module):
+     """
+     Transformer decoder layer with global cross-attention and BoxRPB
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dim_feedforward: int,
+         dropout: float,
+         rpe_hidden_dim: int,
+         feature_stride: int,
+         rpe_type: Literal["linear", "log"],
+     ) -> None:
+         super().__init__()
+
+         self.self_attn = MultiheadAttention(d_model, num_heads, attn_drop=dropout)
+         self.cross_attn = GlobalCrossAttention(
+             embed_dim=d_model,
+             num_heads=num_heads,
+             dropout=dropout,
+             rpe_hidden_dim=rpe_hidden_dim,
+             feature_stride=feature_stride,
+             rpe_type=rpe_type,
+         )
+
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = nn.ReLU()
+
+     def forward(
+         self,
+         tgt: torch.Tensor,
+         memory: torch.Tensor,
+         query_pos: torch.Tensor,
+         memory_pos: torch.Tensor,
+         reference_points: torch.Tensor,
+         spatial_shape: tuple[int, int],
+         memory_key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         tgt2 = self.norm1(tgt)
+         qk = tgt2 + query_pos
+         tgt2 = self.self_attn(qk, qk, tgt2)
+         tgt = tgt + self.dropout1(tgt2)
+
+         tgt2 = self.cross_attn(
+             query=self.norm2(tgt) + query_pos,
+             key=memory + memory_pos,
+             value=memory,
+             reference_points=reference_points,
+             spatial_shape=spatial_shape,
+             key_padding_mask=memory_key_padding_mask,
+         )
+         tgt = tgt + self.dropout2(tgt2)
+
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm3(tgt)))))
+         tgt = tgt + self.dropout3(tgt2)
+
+         return tgt
+
+
+ class GlobalDecoder(nn.Module):
+     def __init__(
+         self, decoder_layer: nn.Module, num_layers: int, norm: nn.Module, return_intermediate: bool, d_model: int
+     ) -> None:
+         super().__init__()
+         self.layers = _get_clones(decoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+         self.return_intermediate = return_intermediate
+         self.d_model = d_model
+
+         self.bbox_embed: Optional[nn.ModuleList] = None
+         self.class_embed: Optional[nn.ModuleList] = None
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.LayerNorm):
+                 nn.init.zeros_(m.bias)
+                 nn.init.ones_(m.weight)
+
+         for m in self.modules():
+             if m is not self and hasattr(m, "reset_parameters") is True and callable(m.reset_parameters) is True:
+                 m.reset_parameters()
+
+     def forward(
+         self,
+         tgt: torch.Tensor,
+         memory: torch.Tensor,
+         query_pos: torch.Tensor,
+         memory_pos: torch.Tensor,
+         reference_points: torch.Tensor,
+         spatial_shape: tuple[int, int],
+         memory_key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         output = tgt
+         intermediate = []
+         intermediate_reference_points = []
+
+         if self.bbox_embed is not None:
+             for layer, bbox_embed in zip(self.layers, self.bbox_embed):
+                 reference_points_input = reference_points.detach().clamp(0, 1)
+
+                 output = layer(
+                     output,
+                     memory,
+                     query_pos=query_pos,
+                     memory_pos=memory_pos,
+                     reference_points=reference_points_input,
+                     spatial_shape=spatial_shape,
+                     memory_key_padding_mask=memory_key_padding_mask,
+                 )
+
+                 output_for_pred = self.norm(output)
+                 tmp = bbox_embed(output_for_pred)
+                 new_reference_points = tmp + inverse_sigmoid(reference_points)
+                 new_reference_points = new_reference_points.sigmoid()
+                 reference_points = new_reference_points.detach()
+
+                 if self.return_intermediate is True:
+                     intermediate.append(output_for_pred)
+                     intermediate_reference_points.append(new_reference_points)
+
+             if self.return_intermediate is True:
+                 return torch.stack(intermediate), torch.stack(intermediate_reference_points)
+
+             return output_for_pred.unsqueeze(0), new_reference_points.unsqueeze(0)
+
+         for layer in self.layers:
+             reference_points_input = reference_points.detach().clamp(0, 1)
+
+             output = layer(
+                 output,
+                 memory,
+                 query_pos=query_pos,
+                 memory_pos=memory_pos,
+                 reference_points=reference_points_input,
+                 spatial_shape=spatial_shape,
+                 memory_key_padding_mask=memory_key_padding_mask,
+             )
+
+             output_for_pred = self.norm(output)
+
+             if self.return_intermediate is True:
+                 intermediate.append(output_for_pred)
+                 intermediate_reference_points.append(reference_points)
+
+         if self.return_intermediate is True:
+             return torch.stack(intermediate), torch.stack(intermediate_reference_points)
+
+         return output_for_pred.unsqueeze(0), reference_points.unsqueeze(0)
+
+
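The per-layer box refinement in GlobalDecoder boils down to a predicted delta in logit space on top of the current reference points. A toy standalone sketch of that update, assuming the usual clamped-logit form of the inverse_sigmoid helper imported from deformable_detr (illustration only):

import torch

def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Standard DETR-style clamped logit
    x = x.clamp(min=0.0, max=1.0)
    return torch.log(x.clamp(min=eps) / (1.0 - x).clamp(min=eps))

# Each layer adds a delta in logit space, squashes back to [0, 1] and detaches
# the refined boxes before the next layer (as in GlobalDecoder.forward)
reference_points = torch.rand(1, 3, 4)  # (batch, queries, cxcywh) in [0, 1]
for _ in range(3):  # e.g. three decoder layers
    delta = 0.1 * torch.randn(1, 3, 4)  # stand-in for bbox_embed(decoder output)
    new_reference_points = (delta + inverse_sigmoid(reference_points)).sigmoid()
    reference_points = new_reference_points.detach()

print(reference_points.min().item() >= 0.0, reference_points.max().item() <= 1.0)  # True True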
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(self, d_model: int, num_heads: int, dim_feedforward: int, dropout: float) -> None:
+         super().__init__()
+         self.self_attn = MultiheadAttention(d_model, num_heads, attn_drop=dropout)
+
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = nn.ReLU()
+
+     def forward(
+         self, src: torch.Tensor, pos: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         src2 = self.norm1(src)
+         q = src2 + pos
+         k = src2 + pos
+
+         src2 = self.self_attn(q, k, src2, key_padding_mask=src_key_padding_mask)
+         src = src + self.dropout1(src2)
+
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(src)))))
+         src = src + self.dropout2(src2)
+
+         return src
+
+
+ class TransformerEncoder(nn.Module):
+     def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+         super().__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+
+     def forward(self, x: torch.Tensor, pos: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         out = x
+         for layer in self.layers:
+             out = layer(out, pos=pos, src_key_padding_mask=mask)
+
+         return out
+
+
+ # pylint: disable=invalid-name
+ class Plain_DETR(DetectionBaseNet):
+     default_size = (640, 640)
+     block_group_regex = r"encoder\.layers\.(\d+)|decoder\.layers\.(\d+)"
+
+     # pylint: disable=too-many-locals
+     def __init__(
+         self,
+         num_classes: int,
+         backbone: DetectorBackbone,
+         *,
+         config: Optional[dict[str, Any]] = None,
+         size: Optional[tuple[int, int]] = None,
+         export_mode: bool = False,
+     ) -> None:
+         super().__init__(num_classes, backbone, config=config, size=size, export_mode=export_mode)
+         assert self.config is not None, "must set config"
+
+         # Sigmoid based classification (like multi-label networks)
+         self.num_classes = self.num_classes - 1
+
+         hidden_dim = 256
+         num_heads = 8
+         dropout = 0.0
+         return_intermediate = True
+         dim_feedforward: int = self.config.get("dim_feedforward", 2048)
+         num_encoder_layers: int = self.config["num_encoder_layers"]
+         num_decoder_layers: int = self.config["num_decoder_layers"]
+         num_queries_one2one: int = self.config.get("num_queries_one2one", 300)
+         num_queries_one2many: int = self.config.get("num_queries_one2many", 0)
+         k_one2many: int = self.config.get("k_one2many", 6)
+         lambda_one2many: float = self.config.get("lambda_one2many", 1.0)
+         rpe_hidden_dim: int = self.config.get("rpe_hidden_dim", 512)
+         rpe_type: Literal["linear", "log"] = self.config.get("rpe_type", "linear")
+         box_refine: bool = self.config.get("box_refine", True)
+         soft_nms: bool = self.config.get("soft_nms", False)
+
+         self.soft_nms = None
+         if soft_nms is True:
+             self.soft_nms = SoftNMS()
+
+         self.hidden_dim = hidden_dim
+         self.num_queries_one2one = num_queries_one2one
+         self.num_queries_one2many = num_queries_one2many
+         self.k_one2many = k_one2many
+         self.lambda_one2many = lambda_one2many
+         self.box_refine = box_refine
+         self.num_queries = self.num_queries_one2one + self.num_queries_one2many
+         if hasattr(self.backbone, "max_stride") is True:
+             self.feature_stride = self.backbone.max_stride
+         else:
+             self.feature_stride = 32
+
+         if num_encoder_layers == 0:
+             self.encoder = None
+         else:
+             encoder_layer = TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward, dropout)
+             self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)
+
+         decoder_layer = GlobalDecoderLayer(
+             hidden_dim,
+             num_heads=num_heads,
+             dim_feedforward=dim_feedforward,
+             dropout=dropout,
+             rpe_hidden_dim=rpe_hidden_dim,
+             feature_stride=self.feature_stride,
+             rpe_type=rpe_type,
+         )
+         decoder_norm = nn.LayerNorm(hidden_dim)
+         self.decoder = GlobalDecoder(
+             decoder_layer,
+             num_decoder_layers,
+             decoder_norm,
+             return_intermediate=return_intermediate,
+             d_model=hidden_dim,
+         )
+
+         self.class_embed = nn.Linear(hidden_dim, self.num_classes)
+         self.bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
+         self.query_embed = nn.Embedding(self.num_queries, hidden_dim * 2)
+         self.reference_point_head = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
+         self.input_proj = nn.Conv2d(
+             self.backbone.return_channels[-1], hidden_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
+         )
+         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
+         self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+
+         if box_refine is True:
+             self.class_embed = _get_clones(self.class_embed, num_decoder_layers)
+             self.bbox_embed = _get_clones(self.bbox_embed, num_decoder_layers)
+             self.decoder.bbox_embed = self.bbox_embed
+         else:
+             self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_decoder_layers)])
+             self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_decoder_layers)])
+
+         if self.export_mode is False:
+             self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]
+
+         # Weights initialization
+         prior_prob = 0.01
+         bias_value = -math.log((1 - prior_prob) / prior_prob)
+         for class_embed in self.class_embed:
+             nn.init.constant_(class_embed.bias, bias_value)
+
+         for idx, bbox_embed in enumerate(self.bbox_embed):
+             last_linear = [m for m in bbox_embed.modules() if isinstance(m, nn.Linear)][-1]
+             nn.init.zeros_(last_linear.weight)
+             nn.init.zeros_(last_linear.bias)
+             if idx == 0:
+                 nn.init.constant_(last_linear.bias[2:], -2.0)  # Small initial wh
+
+         ref_last_linear = [m for m in self.reference_point_head.modules() if isinstance(m, nn.Linear)][-1]
+         nn.init.zeros_(ref_last_linear.weight)
+         nn.init.zeros_(ref_last_linear.bias)
+
+     def reset_classifier(self, num_classes: int) -> None:
+         self.num_classes = num_classes
+         num_decoder_layers = len(self.class_embed)
+         self.class_embed = nn.ModuleList([nn.Linear(self.hidden_dim, num_classes) for _ in range(num_decoder_layers)])
+
+         prior_prob = 0.01
+         bias_value = -math.log((1 - prior_prob) / prior_prob)
+         for class_embed in self.class_embed:
+             nn.init.constant_(class_embed.bias, bias_value)
+
+     def freeze(self, freeze_classifier: bool = True) -> None:
+         for param in self.parameters():
+             param.requires_grad_(False)
+
+         if freeze_classifier is False:
+             for param in self.class_embed.parameters():
+                 param.requires_grad_(True)
+
+     def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+         src_idx = torch.concat([src for (src, _) in indices])
+         return (batch_idx, src_idx)
+
+     def _class_loss(
+         self,
+         cls_logits: torch.Tensor,
+         targets: list[dict[str, torch.Tensor]],
+         indices: list[torch.Tensor],
+         num_boxes: int,
+     ) -> torch.Tensor:
+         idx = self._get_src_permutation_idx(indices)
+         target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
+
+         target_classes_onehot = torch.zeros(
+             cls_logits.size(0),
+             cls_logits.size(1),
+             cls_logits.size(2) + 1,
+             dtype=cls_logits.dtype,
+             device=cls_logits.device,
+         )
+         target_classes_onehot[idx[0], idx[1], target_classes_o] = 1
+         target_classes_onehot = target_classes_onehot[:, :, :-1]
+
+         loss = sigmoid_focal_loss(cls_logits, target_classes_onehot, alpha=0.25, gamma=2.0)
+         loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.size(1)
+
+         return loss_ce
+
+     def _box_loss(
+         self,
+         box_output: torch.Tensor,
+         targets: list[dict[str, torch.Tensor]],
+         indices: list[torch.Tensor],
+         num_boxes: int,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         idx = self._get_src_permutation_idx(indices)
+         src_boxes = box_output[idx]
+         target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+         loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
+         loss_bbox = loss_bbox.sum() / num_boxes
+
+         loss_giou = 1 - torch.diag(
+             box_ops.generalized_box_iou(
+                 box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
+                 box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
+             )
+         )
+         loss_giou = loss_giou.sum() / num_boxes
+
+         return (loss_bbox, loss_giou)
+
+     @torch.jit.unused  # type: ignore[untyped-decorator]
+     @torch.compiler.disable()  # type: ignore[untyped-decorator]
+     def compute_loss(
+         self,
+         targets: list[dict[str, torch.Tensor]],
+         cls_logits: torch.Tensor,
+         box_output: torch.Tensor,
+         cls_logits_one2many: Optional[torch.Tensor] = None,
+         box_output_one2many: Optional[torch.Tensor] = None,
+     ) -> dict[str, torch.Tensor]:
+         # Compute the average number of target boxes across all nodes, for normalization purposes
+         num_boxes = sum(len(t["labels"]) for t in targets)
+         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=cls_logits.device)
+         if training_utils.is_dist_available_and_initialized() is True:
+             torch.distributed.all_reduce(num_boxes)
+
+         num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+
+         loss_ce_list = []
+         loss_bbox_list = []
+         loss_giou_list = []
+         for idx in range(cls_logits.size(0)):
+             indices = self.matcher(cls_logits[idx], box_output[idx], targets)
+             loss_ce_i = self._class_loss(cls_logits[idx], targets, indices, num_boxes)
+             loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
+             loss_ce_list.append(loss_ce_i)
+             loss_bbox_list.append(loss_bbox_i)
+             loss_giou_list.append(loss_giou_i)
+
+         loss_ce = torch.stack(loss_ce_list).sum() * 2
+         loss_bbox = torch.stack(loss_bbox_list).sum() * 5
+         loss_giou = torch.stack(loss_giou_list).sum() * 2
+
+         # One2many loss (hybrid matching)
+         if cls_logits_one2many is not None and box_output_one2many is not None:
+             targets_one2many = [
+                 {"boxes": t["boxes"].repeat(self.k_one2many, 1), "labels": t["labels"].repeat(self.k_one2many)}
+                 for t in targets
+             ]
+             num_boxes_one2many = num_boxes * self.k_one2many
+
+             loss_ce_list_one2many = []
+             loss_bbox_list_one2many = []
+             loss_giou_list_one2many = []
+             for idx in range(cls_logits_one2many.size(0)):
+                 indices = self.matcher(cls_logits_one2many[idx], box_output_one2many[idx], targets_one2many)
+                 loss_ce_i = self._class_loss(cls_logits_one2many[idx], targets_one2many, indices, num_boxes_one2many)
+                 loss_bbox_i, loss_giou_i = self._box_loss(
+                     box_output_one2many[idx], targets_one2many, indices, num_boxes_one2many
+                 )
+                 loss_ce_list_one2many.append(loss_ce_i)
+                 loss_bbox_list_one2many.append(loss_bbox_i)
+                 loss_giou_list_one2many.append(loss_giou_i)
+
+             loss_ce += torch.stack(loss_ce_list_one2many).sum() * 2 * self.lambda_one2many
+             loss_bbox += torch.stack(loss_bbox_list_one2many).sum() * 5 * self.lambda_one2many
+             loss_giou += torch.stack(loss_giou_list_one2many).sum() * 2 * self.lambda_one2many
+
+         losses = {
+             "labels": loss_ce,
+             "boxes": loss_bbox,
+             "giou": loss_giou,
+         }
+
+         return losses
+
+     def postprocess_detections(
+         self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
+     ) -> list[dict[str, torch.Tensor]]:
+         prob = class_logits.sigmoid()
+         scores, labels = prob.max(-1)
+         labels = labels + 1  # Background offset
+
+         # TorchScript doesn't support creating tensor from tuples, convert everything to lists
+         target_sizes = torch.tensor([list(s) for s in image_shapes], device=class_logits.device)
+
+         # Convert to [x0, y0, x1, y1] format
+         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
+
+         # Convert from relative [0, 1] to absolute [0, height] coordinates
+         img_h, img_w = target_sizes.unbind(1)
+         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+         boxes = boxes * scale_fct[:, None, :]
+
+         detections: list[dict[str, torch.Tensor]] = []
+         for s, l, b in zip(scores, labels, boxes):
+             # Non-maximum suppression
+             if self.soft_nms is not None:
+                 soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
+                 s[keep] = soft_scores
+
+                 b = b[keep]
+                 s = s[keep]
+                 l = l[keep]  # noqa: E741
+
+             detections.append(
+                 {
+                     "boxes": b,
+                     "scores": s,
+                     "labels": l,
+                 }
+             )
+
+         return detections
+
+     # pylint: disable=too-many-locals
+     def forward(
+         self,
+         x: torch.Tensor,
+         targets: Optional[list[dict[str, torch.Tensor]]] = None,
+         masks: Optional[torch.Tensor] = None,
+         image_sizes: Optional[list[list[int]]] = None,
+     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
+         self._input_check(targets)
+         images = self._to_img_list(x, image_sizes)
+
+         features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
+         src = features[self.backbone.return_stages[-1]]
+         src = self.input_proj(src)
+         B, _, H, W = src.size()
+
+         if masks is not None:
+             masks = F.interpolate(masks[None].float(), size=(H, W), mode="nearest").to(torch.bool)[0]
+             mask_flatten = masks.flatten(1)
+         else:
+             mask_flatten = None
+
+         pos = self.pos_enc(src, masks)
+         src = src.flatten(2).permute(0, 2, 1)
+         pos = pos.flatten(2).permute(0, 2, 1)
+
+         if self.encoder is not None:
+             memory = self.encoder(src, pos=pos, mask=mask_flatten)
+         else:
+             memory = src
+
+         # Use all queries during training, only one2one during inference
+         if self.training is True and self.num_queries_one2many > 0:
+             num_queries_to_use = self.num_queries_one2one + self.num_queries_one2many
+         else:
+             num_queries_to_use = self.num_queries_one2one
+
+         query_embed = self.query_embed.weight[:num_queries_to_use]
+         query_embed, query_pos = torch.split(query_embed, self.hidden_dim, dim=1)
+         query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
+         query_pos = query_pos.unsqueeze(0).expand(B, -1, -1)
+
+         reference_points = self.reference_point_head(query_pos).sigmoid()
+
+         hs, inter_references = self.decoder(
+             tgt=query_embed,
+             memory=memory,
+             query_pos=query_pos,
+             memory_pos=pos,
+             reference_points=reference_points,
+             spatial_shape=(H, W),
+             memory_key_padding_mask=mask_flatten,
+         )
+
+         outputs_classes = []
+         outputs_coords = []
+         for lvl, (class_embed, bbox_embed) in enumerate(zip(self.class_embed, self.bbox_embed)):
+             outputs_class = class_embed(hs[lvl])
+             outputs_classes.append(outputs_class)
+
+             if self.box_refine is True:
+                 outputs_coord = inter_references[lvl]
+             else:
+                 tmp = bbox_embed(hs[lvl])
+                 tmp = tmp + inverse_sigmoid(reference_points)
+                 outputs_coord = tmp.sigmoid()
+
+             outputs_coords.append(outputs_coord)
+
+         outputs_class = torch.stack(outputs_classes)
+         outputs_coord = torch.stack(outputs_coords)
+
+         losses = {}
+         detections: list[dict[str, torch.Tensor]] = []
+         if self.training is True:
+             assert targets is not None, "targets should not be none when in training mode"
+
+             for idx, target in enumerate(targets):
+                 boxes = target["boxes"]
+                 boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
+                 boxes = boxes / torch.tensor(images.image_sizes[idx][::-1] * 2, dtype=torch.float32, device=x.device)
+                 targets[idx]["boxes"] = boxes
+                 targets[idx]["labels"] = target["labels"] - 1  # No background
+
+             # Split outputs for one2one and one2many
+             outputs_class_one2one = outputs_class[:, :, : self.num_queries_one2one]
+             outputs_coord_one2one = outputs_coord[:, :, : self.num_queries_one2one]
+
+             if self.num_queries_one2many > 0:
+                 outputs_class_one2many = outputs_class[:, :, self.num_queries_one2one :]
+                 outputs_coord_one2many = outputs_coord[:, :, self.num_queries_one2one :]
+             else:
+                 outputs_class_one2many = None
+                 outputs_coord_one2many = None
+
+             losses = self.compute_loss(
+                 targets, outputs_class_one2one, outputs_coord_one2one, outputs_class_one2many, outputs_coord_one2many
+             )
+
+         else:
+             detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1], images.image_sizes)
+
+         return (detections, losses)
+
+
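During training, the hybrid one2one/one2many matching above matches the extra queries against every ground-truth box repeated k_one2many times, with their loss terms scaled by lambda_one2many. A toy sketch of the target duplication performed in compute_loss (standalone illustration, values invented):

import torch

k_one2many = 6
targets = [{"boxes": torch.rand(2, 4), "labels": torch.tensor([3, 7])}]

# Each image's boxes and labels are simply tiled k_one2many times for the one2many branch
targets_one2many = [
    {"boxes": t["boxes"].repeat(k_one2many, 1), "labels": t["labels"].repeat(k_one2many)}
    for t in targets
]
print(targets_one2many[0]["boxes"].shape, targets_one2many[0]["labels"].shape)
# torch.Size([12, 4]) torch.Size([12])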
+ registry.register_model_config(
+     "plain_detr_lite",
+     Plain_DETR,
+     config={"num_encoder_layers": 1, "num_decoder_layers": 3, "box_refine": False},
+ )
+ registry.register_model_config(
+     "plain_detr",
+     Plain_DETR,
+     config={"num_encoder_layers": 0, "num_decoder_layers": 6, "num_queries_one2many": 1500},
+ )