dgenerate-ultralytics-headless 8.3.236__py3-none-any.whl → 8.3.237__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/RECORD +38 -25
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/engine/exporter.py +17 -10
  5. ultralytics/engine/predictor.py +3 -2
  6. ultralytics/engine/trainer.py +8 -0
  7. ultralytics/models/rtdetr/val.py +5 -1
  8. ultralytics/models/sam/__init__.py +14 -1
  9. ultralytics/models/sam/build.py +17 -8
  10. ultralytics/models/sam/build_sam3.py +374 -0
  11. ultralytics/models/sam/model.py +12 -4
  12. ultralytics/models/sam/modules/blocks.py +20 -8
  13. ultralytics/models/sam/modules/decoders.py +2 -3
  14. ultralytics/models/sam/modules/encoders.py +4 -1
  15. ultralytics/models/sam/modules/memory_attention.py +6 -2
  16. ultralytics/models/sam/modules/sam.py +150 -6
  17. ultralytics/models/sam/modules/utils.py +134 -4
  18. ultralytics/models/sam/predict.py +2076 -118
  19. ultralytics/models/sam/sam3/__init__.py +3 -0
  20. ultralytics/models/sam/sam3/decoder.py +546 -0
  21. ultralytics/models/sam/sam3/encoder.py +535 -0
  22. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  23. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  24. ultralytics/models/sam/sam3/model_misc.py +198 -0
  25. ultralytics/models/sam/sam3/necks.py +129 -0
  26. ultralytics/models/sam/sam3/sam3_image.py +357 -0
  27. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  28. ultralytics/models/sam/sam3/tokenizer_ve.py +242 -0
  29. ultralytics/models/sam/sam3/vitdet.py +546 -0
  30. ultralytics/models/sam/sam3/vl_combiner.py +165 -0
  31. ultralytics/models/yolo/obb/val.py +18 -7
  32. ultralytics/nn/modules/transformer.py +21 -1
  33. ultralytics/utils/checks.py +2 -2
  34. ultralytics/utils/ops.py +1 -3
  35. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/WHEEL +0 -0
  36. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/entry_points.txt +0 -0
  37. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/licenses/LICENSE +0 -0
  38. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/top_level.txt +0 -0
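Judging from the file list, SAM 3 support is routed through the existing SAM wrapper (ultralytics/models/sam/model.py, build.py, and the new build_sam3.py). A minimal sketch of how that wrapper is typically driven follows; the "sam3.pt" checkpoint name is hypothetical, since the exact weight names registered by this release are not visible in the diff.

    from ultralytics import SAM

    # Hypothetical checkpoint name for illustration only; the weights actually
    # registered by build.py / build_sam3.py are not shown in this diff.
    model = SAM("sam3.pt")
    model.info()

    # Prompted inference via the existing SAM predict interface.
    results = model("image.jpg", bboxes=[100, 100, 200, 200])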
@@ -0,0 +1,3 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
@@ -0,0 +1,546 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+ """
+ Transformer decoder.
+ Inspired from Pytorch's version, adds the pre-norm variant.
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+ from torch import nn
+ from torchvision.ops.roi_align import RoIAlign
+
+ from ultralytics.nn.modules.transformer import MLP
+ from ultralytics.nn.modules.utils import _get_clones, inverse_sigmoid
+ from ultralytics.utils.ops import xywh2xyxy
+
+ from .model_misc import gen_sineembed_for_position
+
+
+ class TransformerDecoderLayer(nn.Module):
+     """TransformerDecoderLayer is made up of self-attn, cross-attn, and feedforward network (FFN)."""
+
+     def __init__(
+         self,
+         d_model: int,
+         dim_feedforward: int,
+         dropout: float,
+         cross_attention: nn.Module,
+         n_heads: int,
+         use_text_cross_attention: bool = False,
+     ):
+         """Initialize the TransformerDecoderLayer."""
+         super().__init__()
+         # cross attention
+         self.cross_attn = cross_attention
+         self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+         self.norm1 = nn.LayerNorm(d_model)
+
+         # cross attention text
+         self.use_text_cross_attention = use_text_cross_attention
+         if use_text_cross_attention:
+             self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+             self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+             self.catext_norm = nn.LayerNorm(d_model)
+
+         # self attention
+         self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+         self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+         self.norm2 = nn.LayerNorm(d_model)
+
+         # ffn
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.activation = nn.ReLU()
+         self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+         self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+         self.norm3 = nn.LayerNorm(d_model)
+
+     @staticmethod
+     def with_pos_embed(tensor, pos):
+         """Add positional embedding to the tensor."""
+         return tensor if pos is None else tensor + pos
+
+     def forward_ffn(self, tgt):
+         """Feedforward network forward pass."""
+         tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout4(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+
+     def forward(
+         self,
+         # for tgt
+         tgt: torch.Tensor,  # nq, bs, d_model
+         tgt_query_pos: torch.Tensor = None,  # pos for query. MLP(Sine(pos))
+         memory_text: torch.Tensor = None,  # num_token, bs, d_model
+         text_attention_mask: torch.Tensor = None,  # bs, num_token
+         # for memory
+         memory: torch.Tensor = None,  # hw, bs, d_model
+         memory_key_padding_mask: torch.Tensor = None,
+         memory_pos: torch.Tensor = None,  # pos for memory
+         # sa
+         self_attn_mask: torch.Tensor = None,  # mask used for self-attention
+         cross_attn_mask: torch.Tensor = None,  # mask used for cross-attention
+         # dac
+         dac=False,
+         dac_use_selfatt_ln=True,
+         presence_token=None,
+         # skip inside deformable attn
+         **kwargs,  # additional kwargs for compatibility
+     ):
+         """Input: - tgt/tgt_query_pos: nq, bs, d_model. -."""
+         # self attention
+         tgt, tgt_query_pos = self._apply_self_attention(
+             tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask
+         )
+
+         if self.use_text_cross_attention:
+             tgt2 = self.ca_text(
+                 self.with_pos_embed(tgt, tgt_query_pos),
+                 memory_text.to(tgt.dtype),
+                 memory_text.to(tgt.dtype),
+                 key_padding_mask=text_attention_mask,
+             )[0]
+             tgt = tgt + self.catext_dropout(tgt2)
+             tgt = self.catext_norm(tgt)
+
+         if presence_token is not None:
+             presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
+             cross_attn_mask = torch.cat([presence_token_mask, cross_attn_mask], dim=1)  # (bs*nheads, 1+nq, hw)
+
+         # Cross attention to image
+         tgt2 = self.cross_attn(
+             query=self.with_pos_embed(tgt, tgt_query_pos),
+             key=self.with_pos_embed(memory, memory_pos),
+             value=memory,
+             attn_mask=cross_attn_mask,
+             key_padding_mask=(memory_key_padding_mask.transpose(0, 1) if memory_key_padding_mask is not None else None),
+             need_weights=False,
+         )[0]
+
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+
+         # ffn
+         tgt = self.forward_ffn(tgt.to(memory.dtype))
+
+         presence_token_out = None
+         if presence_token is not None:
+             presence_token_out = tgt[:1]
+             tgt = tgt[1:]
+
+         return tgt, presence_token_out
+
+     def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask):
+         """Apply self-attention with optional DAC splitting."""
+         if self.self_attn is None:
+             return tgt
+
+         if dac:
+             # Split queries for DAC (detect-and-classify)
+             assert tgt.shape[0] % 2 == 0, "DAC requires even number of queries"
+             num_o2o_queries = tgt.shape[0] // 2
+             tgt_o2o = tgt[:num_o2o_queries]
+             tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
+             tgt_o2m = tgt[num_o2o_queries:]
+         else:
+             tgt_o2o = tgt
+             tgt_query_pos_o2o = tgt_query_pos
+
+         # Handle presence token
+         if presence_token is not None:
+             tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
+             tgt_query_pos_o2o = torch.cat([torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0).to(
+                 tgt_o2o.dtype
+             )
+             tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0)
+
+         # Self-attention
+         q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
+         tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0].to(tgt.dtype)
+         tgt_o2o = tgt_o2o + self.dropout2(tgt2)
+
+         # Recombine and normalize
+         if dac:
+             if not dac_use_selfatt_ln:
+                 tgt_o2o = self.norm2(tgt_o2o)
+             tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)
+             if dac_use_selfatt_ln:
+                 tgt = self.norm2(tgt)
+         else:
+             tgt = tgt_o2o
+             tgt = self.norm2(tgt)
+
+         return tgt, tgt_query_pos
+
+
+ class TransformerDecoder(nn.Module):
+     """Transformer Decoder consisting of multiple layers."""
+
+     def __init__(
+         self,
+         d_model: int,
+         frozen: bool,
+         interaction_layer,
+         layer,
+         num_layers: int,
+         num_queries: int,
+         return_intermediate: bool,
+         box_refine: bool = False,
+         num_o2m_queries: int = 0,
+         dac: bool = False,
+         boxRPB: str = "none",
+         # Experimental: An object query for SAM 2 tasks
+         instance_query: bool = False,
+         # Defines the number of additional instance queries,
+         # 1 or 4 are the most likely for single vs multi mask support
+         num_instances: int = 1,  # Irrelevant if instance_query is False
+         dac_use_selfatt_ln: bool = True,
+         use_act_checkpoint: bool = False,
+         compile_mode=None,
+         presence_token: bool = False,
+         clamp_presence_logits: bool = True,
+         clamp_presence_logit_max_val: float = 10.0,
+         use_normed_output_consistently: bool = True,
+         separate_box_head_instance: bool = False,
+         separate_norm_instance: bool = False,
+     ):
+         """Initialize the TransformerDecoder."""
+         super().__init__()
+         self.d_model = d_model
+         self.layers = _get_clones(layer, num_layers)
+         self.fine_layers = (
+             _get_clones(interaction_layer, num_layers) if interaction_layer is not None else [None] * num_layers
+         )
+         self.num_layers = num_layers
+         self.num_queries = num_queries
+         self.dac = dac
+         if dac:
+             self.num_o2m_queries = num_queries
+             tot_num_queries = num_queries
+         else:
+             self.num_o2m_queries = num_o2m_queries
+             tot_num_queries = num_queries + num_o2m_queries
+         self.norm = nn.LayerNorm(d_model)
+         self.return_intermediate = return_intermediate
+         self.bbox_embed = MLP(d_model, d_model, 4, 3)
+         self.query_embed = nn.Embedding(tot_num_queries, d_model)
+         self.instance_query_embed = None
+         self.instance_query_reference_points = None
+         self.use_instance_query = instance_query
+         self.num_instances = num_instances
+         self.use_normed_output_consistently = use_normed_output_consistently
+
+         self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
+         self.instance_bbox_embed = None
+         if separate_box_head_instance:
+             self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
+         if instance_query:
+             self.instance_query_embed = nn.Embedding(num_instances, d_model)
+         self.box_refine = box_refine
+         if box_refine:
+             nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+             nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+
+         self.reference_points = nn.Embedding(num_queries, 4)
+         if instance_query:
+             self.instance_reference_points = nn.Embedding(num_instances, 4)
+
+         assert boxRPB in ["none", "log", "linear", "both"]
+         self.boxRPB = boxRPB
+         if boxRPB != "none":
+             try:
+                 nheads = self.layers[0].cross_attn_image.num_heads
+             except AttributeError:
+                 nheads = self.layers[0].cross_attn.num_heads
+
+             n_input = 4 if boxRPB == "both" else 2
+             self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
+             self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
+             self.compilable_cord_cache = None
+             self.compilable_stored_size = None
+             self.coord_cache = {}
+
+         self.roi_pooler = (
+             RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
+             if interaction_layer is not None
+             else None
+         )
+         if frozen:
+             for p in self.parameters():
+                 p.requires_grad_(False)
+
+         self.presence_token = None
+         self.clamp_presence_logits = clamp_presence_logits
+         self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
+         if presence_token:
+             self.presence_token = nn.Embedding(1, d_model)
+             self.presence_token_head = MLP(d_model, d_model, 1, 3)
+             self.presence_token_out_norm = nn.LayerNorm(d_model)
+
+         self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
+         self.dac_use_selfatt_ln = dac_use_selfatt_ln
+         self.use_act_checkpoint = use_act_checkpoint
+
+         nn.init.normal_(self.query_embed.weight.data)
+         if self.instance_query_embed is not None:
+             nn.init.normal_(self.instance_query_embed.weight.data)
+
+         assert self.roi_pooler is None
+         assert self.return_intermediate, "support return_intermediate only"
+         assert self.box_refine, "support box refine only"
+
+         self.compile_mode = compile_mode
+         self.compiled = False
+         # We defer compilation till after the first forward, to first warm-up the boxRPB cache
+
+         # assign layer index to each layer so that some layers can decide what to do
+         # based on which layer index they are (e.g. cross attention to memory bank only
+         # in selected layers)
+         for layer_idx, layer in enumerate(self.layers):
+             layer.layer_idx = layer_idx
+
+     @staticmethod
+     def _get_coords(H, W, device, dtype):
+         """Get normalized coordinates for height and width."""
+         coords_h = torch.arange(0, H, dtype=dtype, device=device) / H
+         coords_w = torch.arange(0, W, dtype=dtype, device=device) / W
+         return coords_h, coords_w
+
+     def _get_rpb_matrix(self, reference_boxes, feat_size):
+         """Get the relative position bias (RPB) matrix for box-relative position bias."""
+         H, W = feat_size
+         boxes_xyxy = xywh2xyxy(reference_boxes).transpose(0, 1)
+         bs, num_queries, _ = boxes_xyxy.shape
+         if self.compilable_cord_cache is None:
+             self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
+             self.compilable_stored_size = (H, W)
+
+         if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
+             H,
+             W,
+         ):
+             # good, hitting the cache, will be compilable
+             coords_h, coords_w = self.compilable_cord_cache
+         else:
+             # cache miss, will create compilation issue
+             # In case we're not compiling, we'll still rely on the dict-based cache
+             if feat_size not in self.coord_cache:
+                 self.coord_cache[feat_size] = self._get_coords(H, W, reference_boxes.device)
+             coords_h, coords_w = self.coord_cache[feat_size]
+
+         assert coords_h.shape == (H,)
+         assert coords_w.shape == (W,)
+
+         deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
+         deltas_y = deltas_y.view(bs, num_queries, -1, 2)
+         deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
+         deltas_x = deltas_x.view(bs, num_queries, -1, 2)
+
+         if self.boxRPB in ["log", "both"]:
+             deltas_x_log = deltas_x * 8  # normalize to -8, 8
+             deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8)
+
+             deltas_y_log = deltas_y * 8  # normalize to -8, 8
+             deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8)
+             if self.boxRPB == "log":
+                 deltas_x = deltas_x_log
+                 deltas_y = deltas_y_log
+             else:
+                 deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
+                 deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
+
+         if self.training:
+             assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
+         deltas_x = self.boxRPB_embed_x(x=deltas_x)  # bs, num_queries, W, n_heads
+         deltas_y = self.boxRPB_embed_y(x=deltas_y)  # bs, num_queries, H, n_heads
+
+         if not torch.compiler.is_dynamo_compiling():
+             assert deltas_x.shape[:3] == (bs, num_queries, W)
+             assert deltas_y.shape[:3] == (bs, num_queries, H)
+
+         B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2)  # bs, num_queries, H, W, n_heads
+         if not torch.compiler.is_dynamo_compiling():
+             assert B.shape[:4] == (bs, num_queries, H, W)
+         B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
+         B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
+         B = B.contiguous()  # memeff attn likes ordered strides
+         if not torch.compiler.is_dynamo_compiling():
+             assert B.shape[2:] == (num_queries, H * W)
+         return B
+
+     def forward(
+         self,
+         tgt,
+         memory,
+         tgt_mask: torch.Tensor = None,
+         memory_mask: torch.Tensor = None,
+         memory_key_padding_mask: torch.Tensor = None,
+         pos: torch.Tensor = None,
+         reference_boxes: torch.Tensor = None,  # num_queries, bs, 4
+         # for memory
+         spatial_shapes: torch.Tensor = None,  # bs, num_levels, 2
+         valid_ratios: torch.Tensor = None,
+         # for text
+         memory_text: torch.Tensor = None,
+         text_attention_mask: torch.Tensor = None,
+         # if `apply_dac` is None, it will default to `self.dac`
+         apply_dac: bool | None = None,
+         is_instance_prompt=False,
+         decoder_extra_kwargs: dict | None = None,
+         # ROI memory bank
+         obj_roi_memory_feat=None,
+         obj_roi_memory_mask=None,
+         box_head_trk=None,
+     ):
+         """Forward pass of the TransformerDecoder."""
+         if memory_mask is not None:
+             assert self.boxRPB == "none", (
+                 "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
+             )
+
+         apply_dac = apply_dac if apply_dac is not None else self.dac
+         if apply_dac:
+             assert (tgt.shape[0] == self.num_queries) or (
+                 self.use_instance_query and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
+             )
+
+             tgt = tgt.repeat(2, 1, 1)
+             # note that we don't tile tgt_mask, since DAC doesn't
+             # use self-attention in o2m queries
+             if reference_boxes is not None:
+                 assert (reference_boxes.shape[0] == self.num_queries) or (
+                     self.use_instance_query and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings)
+                 )
+                 reference_boxes = reference_boxes.repeat(2, 1, 1)
+
+         bs = tgt.shape[1]
+         intermediate = []
+         intermediate_presence_logits = []
+         presence_feats = None
+
+         if self.box_refine:
+             if reference_boxes is None:
+                 # In this case, we're in a one-stage model, so we generate the reference boxes
+                 reference_boxes = self.reference_points.weight.unsqueeze(1)
+                 reference_boxes = reference_boxes.repeat(2, bs, 1) if apply_dac else reference_boxes.repeat(1, bs, 1)
+                 reference_boxes = reference_boxes.sigmoid()
+             intermediate_ref_boxes = [reference_boxes]
+         else:
+             reference_boxes = None
+             intermediate_ref_boxes = None
+
+         output = tgt
+         presence_out = None
+         if self.presence_token is not None and is_instance_prompt is False:
+             # expand to batch dim
+             presence_out = self.presence_token.weight[None].expand(1, bs, -1)
+
+         box_head = self.bbox_embed
+         if is_instance_prompt and self.instance_bbox_embed is not None:
+             box_head = self.instance_bbox_embed
+
+         out_norm = self.norm
+         if is_instance_prompt and self.instance_norm is not None:
+             out_norm = self.instance_norm
+
+         for layer_idx, layer in enumerate(self.layers):
+             reference_points_input = (
+                 reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
+             )  # nq, bs, nlevel, 4
+
+             query_sine_embed = gen_sineembed_for_position(
+                 reference_points_input[:, :, 0, :], self.d_model
+             )  # nq, bs, d_model*2
+
+             # conditional query
+             query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, d_model
+
+             if self.boxRPB != "none" and reference_boxes is not None:
+                 assert spatial_shapes.shape[0] == 1, "only single scale support implemented"
+                 memory_mask = self._get_rpb_matrix(
+                     reference_boxes,
+                     (spatial_shapes[0, 0], spatial_shapes[0, 1]),
+                 )
+                 memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)
+             if self.training:
+                 assert self.use_act_checkpoint, "Activation checkpointing not enabled in the decoder"
+             output, presence_out = layer(
+                 tgt=output,
+                 tgt_query_pos=query_pos,
+                 memory_text=memory_text,
+                 text_attention_mask=text_attention_mask,
+                 memory=memory,
+                 memory_key_padding_mask=memory_key_padding_mask,
+                 memory_pos=pos,
+                 self_attn_mask=tgt_mask,
+                 cross_attn_mask=memory_mask,
+                 dac=apply_dac,
+                 dac_use_selfatt_ln=self.dac_use_selfatt_ln,
+                 presence_token=presence_out,
+                 **(decoder_extra_kwargs or {}),
+                 # ROI memory bank
+                 obj_roi_memory_feat=obj_roi_memory_feat,
+                 obj_roi_memory_mask=obj_roi_memory_mask,
+             )
+
+             # iter update
+             if self.box_refine:
+                 reference_before_sigmoid = inverse_sigmoid(reference_boxes)
+                 if box_head_trk is None:
+                     # delta_unsig = self.bbox_embed(output)
+                     if not self.use_normed_output_consistently:
+                         delta_unsig = box_head(output)
+                     else:
+                         delta_unsig = box_head(out_norm(output))
+                 else:
+                     # box_head_trk use a separate box head for tracking queries
+                     Q_det = decoder_extra_kwargs["Q_det"]
+                     assert output.size(0) >= Q_det
+                     delta_unsig_det = self.bbox_embed(output[:Q_det])
+                     delta_unsig_trk = box_head_trk(output[Q_det:])
+                     delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
+                 outputs_unsig = delta_unsig + reference_before_sigmoid
+                 new_reference_points = outputs_unsig.sigmoid()
+
+                 reference_boxes = new_reference_points.detach()
+                 if layer_idx != self.num_layers - 1:
+                     intermediate_ref_boxes.append(new_reference_points)
+             else:
+                 raise NotImplementedError("not implemented yet")
+
+             intermediate.append(out_norm(output))
+             if self.presence_token is not None and is_instance_prompt is False:
+                 # norm, mlp head
+                 intermediate_layer_presence_logits = self.presence_token_head(
+                     self.presence_token_out_norm(presence_out)
+                 ).squeeze(-1)
+
+                 # clamp to mitigate numerical issues
+                 if self.clamp_presence_logits:
+                     intermediate_layer_presence_logits.clamp(
+                         min=-self.clamp_presence_logit_max_val,
+                         max=self.clamp_presence_logit_max_val,
+                     )
+
+                 intermediate_presence_logits.append(intermediate_layer_presence_logits)
+                 presence_feats = presence_out.clone()
+
+         if not self.compiled and self.compile_mode is not None:
+             self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True)
+             self.compiled = True
+
+         return (
+             torch.stack(intermediate),
+             torch.stack(intermediate_ref_boxes),
+             (
+                 torch.stack(intermediate_presence_logits)
+                 if self.presence_token is not None and is_instance_prompt is False
+                 else None
+             ),
+             presence_feats,
+         )
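For orientation, a minimal sketch of exercising the TransformerDecoderLayer added above in isolation. The import path is assumed from the file list (ultralytics/models/sam/sam3/decoder.py), and a plain nn.MultiheadAttention is substituted for whatever cross-attention module the full SAM 3 decoder stack actually wires in.

    import torch
    from torch import nn

    from ultralytics.models.sam.sam3.decoder import TransformerDecoderLayer  # path assumed from the file list above

    d_model, n_heads = 256, 8
    layer = TransformerDecoderLayer(
        d_model=d_model,
        dim_feedforward=2048,
        dropout=0.0,
        cross_attention=nn.MultiheadAttention(d_model, n_heads),  # stand-in cross-attention module
        n_heads=n_heads,
    ).eval()

    tgt = torch.randn(100, 2, d_model)  # (num_queries, batch, d_model) object queries
    memory = torch.randn(64 * 64, 2, d_model)  # (H*W, batch, d_model) flattened image features
    with torch.no_grad():
        out, presence = layer(tgt=tgt, memory=memory)
    print(out.shape, presence)  # torch.Size([100, 2, 256]) None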