dgenerate-ultralytics-headless 8.3.236__py3-none-any.whl → 8.3.237__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/RECORD +38 -25
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/engine/exporter.py +17 -10
  5. ultralytics/engine/predictor.py +3 -2
  6. ultralytics/engine/trainer.py +8 -0
  7. ultralytics/models/rtdetr/val.py +5 -1
  8. ultralytics/models/sam/__init__.py +14 -1
  9. ultralytics/models/sam/build.py +17 -8
  10. ultralytics/models/sam/build_sam3.py +374 -0
  11. ultralytics/models/sam/model.py +12 -4
  12. ultralytics/models/sam/modules/blocks.py +20 -8
  13. ultralytics/models/sam/modules/decoders.py +2 -3
  14. ultralytics/models/sam/modules/encoders.py +4 -1
  15. ultralytics/models/sam/modules/memory_attention.py +6 -2
  16. ultralytics/models/sam/modules/sam.py +150 -6
  17. ultralytics/models/sam/modules/utils.py +134 -4
  18. ultralytics/models/sam/predict.py +2076 -118
  19. ultralytics/models/sam/sam3/__init__.py +3 -0
  20. ultralytics/models/sam/sam3/decoder.py +546 -0
  21. ultralytics/models/sam/sam3/encoder.py +535 -0
  22. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  23. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  24. ultralytics/models/sam/sam3/model_misc.py +198 -0
  25. ultralytics/models/sam/sam3/necks.py +129 -0
  26. ultralytics/models/sam/sam3/sam3_image.py +357 -0
  27. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  28. ultralytics/models/sam/sam3/tokenizer_ve.py +242 -0
  29. ultralytics/models/sam/sam3/vitdet.py +546 -0
  30. ultralytics/models/sam/sam3/vl_combiner.py +165 -0
  31. ultralytics/models/yolo/obb/val.py +18 -7
  32. ultralytics/nn/modules/transformer.py +21 -1
  33. ultralytics/utils/checks.py +2 -2
  34. ultralytics/utils/ops.py +1 -3
  35. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/WHEEL +0 -0
  36. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/entry_points.txt +0 -0
  37. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/licenses/LICENSE +0 -0
  38. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/build_sam3.py (new file)
@@ -0,0 +1,374 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ import torch.nn as nn
+
+ from ultralytics.nn.modules.transformer import MLP
+ from ultralytics.utils.patches import torch_load
+
+ from .modules.blocks import PositionEmbeddingSine, RoPEAttention
+ from .modules.encoders import MemoryEncoder
+ from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+ from .modules.sam import SAM3Model
+ from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
+ from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
+ from .sam3.geometry_encoders import SequenceGeometryEncoder
+ from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
+ from .sam3.model_misc import DotProductScoring, TransformerWrapper
+ from .sam3.necks import Sam3DualViTDetNeck
+ from .sam3.sam3_image import SAM3SemanticModel
+ from .sam3.text_encoder_ve import VETextEncoder
+ from .sam3.tokenizer_ve import SimpleTokenizer
+ from .sam3.vitdet import ViT
+ from .sam3.vl_combiner import SAM3VLBackbone
+
+
+ def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
+     """Create SAM3 visual backbone with ViT and neck."""
+     # Position encoding
+     position_encoding = PositionEmbeddingSine(
+         num_pos_feats=256,
+         normalize=True,
+         scale=None,
+         temperature=10000,
+     )
+
+     # ViT backbone
+     vit_backbone = ViT(
+         img_size=1008,
+         pretrain_img_size=336,
+         patch_size=14,
+         embed_dim=1024,
+         depth=32,
+         num_heads=16,
+         mlp_ratio=4.625,
+         norm_layer="LayerNorm",
+         drop_path_rate=0.1,
+         qkv_bias=True,
+         use_abs_pos=True,
+         tile_abs_pos=True,
+         global_att_blocks=(7, 15, 23, 31),
+         rel_pos_blocks=(),
+         use_rope=True,
+         use_interp_rope=True,
+         window_size=24,
+         pretrain_use_cls_token=True,
+         retain_cls_token=False,
+         ln_pre=True,
+         ln_post=False,
+         return_interm_layers=False,
+         bias_patch_embed=False,
+         compile_mode=compile_mode,
+     )
+     return Sam3DualViTDetNeck(
+         position_encoding=position_encoding,
+         d_model=256,
+         scale_factors=[4.0, 2.0, 1.0, 0.5],
+         trunk=vit_backbone,
+         add_sam2_neck=enable_inst_interactivity,
+     )
+
+
+ def _create_sam3_transformer() -> TransformerWrapper:
+     """Create SAM3 detector encoder and decoder."""
+     encoder: TransformerEncoderFusion = TransformerEncoderFusion(
+         layer=TransformerEncoderLayer(
+             d_model=256,
+             dim_feedforward=2048,
+             dropout=0.1,
+             pos_enc_at_attn=True,
+             pos_enc_at_cross_attn_keys=False,
+             pos_enc_at_cross_attn_queries=False,
+             pre_norm=True,
+             self_attention=nn.MultiheadAttention(
+                 num_heads=8,
+                 dropout=0.1,
+                 embed_dim=256,
+                 batch_first=True,
+             ),
+             cross_attention=nn.MultiheadAttention(
+                 num_heads=8,
+                 dropout=0.1,
+                 embed_dim=256,
+                 batch_first=True,
+             ),
+         ),
+         num_layers=6,
+         d_model=256,
+         num_feature_levels=1,
+         frozen=False,
+         use_act_checkpoint=True,
+         add_pooled_text_to_img_feat=False,
+         pool_text_with_mask=True,
+     )
+     decoder: TransformerDecoder = TransformerDecoder(
+         layer=TransformerDecoderLayer(
+             d_model=256,
+             dim_feedforward=2048,
+             dropout=0.1,
+             cross_attention=nn.MultiheadAttention(
+                 num_heads=8,
+                 dropout=0.1,
+                 embed_dim=256,
+             ),
+             n_heads=8,
+             use_text_cross_attention=True,
+         ),
+         num_layers=6,
+         num_queries=200,
+         return_intermediate=True,
+         box_refine=True,
+         num_o2m_queries=0,
+         dac=True,
+         boxRPB="log",
+         d_model=256,
+         frozen=False,
+         interaction_layer=None,
+         dac_use_selfatt_ln=True,
+         use_act_checkpoint=True,
+         presence_token=True,
+     )
+
+     return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
+
+
+ def build_sam3_image_model(
+     checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False
+ ):
+     """Build SAM3 image model.
+
+     Args:
+         checkpoint_path: Path to the model checkpoint.
+         bpe_path: Path to the BPE tokenizer vocabulary.
+         enable_segmentation: Whether to enable the segmentation head.
+         compile: Whether to compile the model (uses the "default" compile mode).
+
+     Returns:
+         A SAM3 image model.
+     """
+     # Create visual components
+     compile_mode = "default" if compile else None
+     vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)
+
+     # Create text components
+     text_encoder = VETextEncoder(
+         tokenizer=SimpleTokenizer(bpe_path=bpe_path),
+         d_model=256,
+         width=1024,
+         heads=16,
+         layers=24,
+     )
+
+     # Create visual-language backbone
+     backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)
+
+     # Create transformer components
+     transformer = _create_sam3_transformer()
+
+     # Create dot product scoring
+     dot_prod_scoring = DotProductScoring(
+         d_model=256,
+         d_proj=256,
+         prompt_mlp=MLP(
+             input_dim=256,
+             hidden_dim=2048,
+             output_dim=256,
+             num_layers=2,
+             residual=True,
+             out_norm=nn.LayerNorm(256),
+         ),
+     )
+
+     # Create segmentation head if enabled
+     segmentation_head = (
+         UniversalSegmentationHead(
+             hidden_dim=256,
+             upsampling_stages=3,
+             aux_masks=False,
+             presence_head=False,
+             dot_product_scorer=None,
+             act_ckpt=True,
+             cross_attend_prompt=nn.MultiheadAttention(
+                 num_heads=8,
+                 dropout=0,
+                 embed_dim=256,
+             ),
+             pixel_decoder=PixelDecoder(
+                 num_upsampling_stages=3,
+                 interpolation_mode="nearest",
+                 hidden_dim=256,
+                 compile_mode=compile_mode,
+             ),
+         )
+         if enable_segmentation
+         else None
+     )
+
+     # Create geometry encoder
+     input_geometry_encoder = SequenceGeometryEncoder(
+         pos_enc=PositionEmbeddingSine(
+             num_pos_feats=256,
+             normalize=True,
+             scale=None,
+             temperature=10000,
+         ),
+         encode_boxes_as_points=False,
+         boxes_direct_project=True,
+         boxes_pool=True,
+         boxes_pos_enc=True,
+         d_model=256,
+         num_layers=3,
+         layer=TransformerEncoderLayer(
+             d_model=256,
+             dim_feedforward=2048,
+             dropout=0.1,
+             pos_enc_at_attn=False,
+             pre_norm=True,
+             pos_enc_at_cross_attn_queries=False,
+             pos_enc_at_cross_attn_keys=True,
+         ),
+         use_act_ckpt=True,
+         add_cls=True,
+         add_post_encode_proj=True,
+     )
+
+     # Create the SAM3SemanticModel model
+     model = SAM3SemanticModel(
+         backbone=backbone,
+         transformer=transformer,
+         input_geometry_encoder=input_geometry_encoder,
+         segmentation_head=segmentation_head,
+         num_feature_levels=1,
+         o2m_mask_predict=True,
+         dot_prod_scoring=dot_prod_scoring,
+         use_instance_query=False,
+         multimask_output=True,
+     )
+
+     # Load checkpoint
+     model = _load_checkpoint(model, checkpoint_path)
+     model.eval()
+     return model
+
+
+ def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
+     """Build the SAM3 Tracker module for video tracking.
+
+     Returns:
+         SAM3Model: The SAM3 tracker model for video tracking.
+     """
+     # Create model components
+     memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
+     memory_attention = MemoryAttention(
+         batch_first=True,
+         d_model=256,
+         pos_enc_at_input=True,
+         layer=MemoryAttentionLayer(
+             dim_feedforward=2048,
+             dropout=0.1,
+             pos_enc_at_attn=False,
+             pos_enc_at_cross_attn_keys=True,
+             pos_enc_at_cross_attn_queries=False,
+             self_attn=RoPEAttention(
+                 embedding_dim=256,
+                 num_heads=1,
+                 downsample_rate=1,
+                 rope_theta=10000.0,
+                 feat_sizes=[72, 72],
+             ),
+             d_model=256,
+             cross_attn=RoPEAttention(
+                 embedding_dim=256,
+                 num_heads=1,
+                 downsample_rate=1,
+                 kv_in_dim=64,
+                 rope_theta=10000.0,
+                 feat_sizes=[72, 72],
+                 rope_k_repeat=True,
+             ),
+         ),
+         num_layers=4,
+     )
+
+     backbone = (
+         SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)
+         if with_backbone
+         else None
+     )
+     model = SAM3Model(
+         image_size=1008,
+         image_encoder=backbone,
+         memory_attention=memory_attention,
+         memory_encoder=memory_encoder,
+         backbone_stride=14,
+         num_maskmem=7,
+         sigmoid_scale_for_mem_enc=20.0,
+         sigmoid_bias_for_mem_enc=-10.0,
+         use_mask_input_as_output_without_sam=True,
+         directly_add_no_mem_embed=True,
+         use_high_res_features_in_sam=True,
+         multimask_output_in_sam=True,
+         iou_prediction_use_sigmoid=True,
+         use_obj_ptrs_in_encoder=True,
+         add_tpos_enc_to_obj_ptrs=True,
+         only_obj_ptrs_in_the_past_for_eval=True,
+         pred_obj_scores=True,
+         pred_obj_scores_mlp=True,
+         fixed_no_obj_ptr=True,
+         multimask_output_for_tracking=True,
+         use_multimask_token_for_obj_ptr=True,
+         multimask_min_pt_num=0,
+         multimask_max_pt_num=1,
+         use_mlp_for_obj_ptr_proj=True,
+         compile_image_encoder=False,
+         no_obj_embed_spatial=True,
+         proj_tpos_enc_in_obj_ptrs=True,
+         use_signed_tpos_enc_to_obj_ptrs=True,
+         sam_mask_decoder_extra_args=dict(
+             dynamic_multimask_via_stability=True,
+             dynamic_multimask_stability_delta=0.05,
+             dynamic_multimask_stability_thresh=0.98,
+         ),
+     )
+
+     # Load checkpoint
+     model = _load_checkpoint(model, checkpoint_path, interactive=True)
+
+     # Setup device and mode
+     model.eval()
+     return model
+
+
+ def _load_checkpoint(model, checkpoint, interactive=False):
+     """Load SAM3 model checkpoint from file."""
+     with open(checkpoint, "rb") as f:
+         ckpt = torch_load(f)
+     if "model" in ckpt and isinstance(ckpt["model"], dict):
+         ckpt = ckpt["model"]
+     sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
+     if interactive:
+         sam3_image_ckpt.update(
+             {
+                 k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v
+                 for k, v in sam3_image_ckpt.items()
+                 if "backbone.vision_backbone" in k
+             }
+         )
+         sam3_image_ckpt.update(
+             {
+                 k.replace("tracker.transformer.encoder", "memory_attention"): v
+                 for k, v in ckpt.items()
+                 if "tracker.transformer" in k
+             }
+         )
+         sam3_image_ckpt.update(
+             {
+                 k.replace("tracker.maskmem_backbone", "memory_encoder"): v
+                 for k, v in ckpt.items()
+                 if "tracker.maskmem_backbone" in k
+             }
+         )
+         sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." in k})
+     model.load_state_dict(sam3_image_ckpt, strict=False)
+     return model
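
For orientation, a minimal usage sketch of the two builders added above. The checkpoint and tokenizer-vocabulary paths are hypothetical placeholders, not assets shipped in this wheel.

from ultralytics.models.sam.build_sam3 import build_interactive_sam3, build_sam3_image_model

# Promptable image/text model: needs a checkpoint plus a BPE tokenizer vocabulary.
image_model = build_sam3_image_model(
    checkpoint_path="sam3.pt",                # hypothetical path
    bpe_path="bpe_simple_vocab_16e6.txt.gz",  # hypothetical path
    enable_segmentation=True,
)

# Interactive tracker variant, which the SAM wrapper below uses for video tracking.
tracker = build_interactive_sam3("sam3.pt")   # hypothetical path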
ultralytics/models/sam/model.py
@@ -21,7 +21,7 @@ from pathlib import Path
  from ultralytics.engine.model import Model
  from ultralytics.utils.torch_utils import model_info
 
- from .predict import Predictor, SAM2Predictor
+ from .predict import Predictor, SAM2Predictor, SAM3Predictor
 
 
  class SAM(Model):
@@ -59,6 +59,7 @@ class SAM(Model):
          if model and Path(model).suffix not in {".pt", ".pth"}:
              raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
          self.is_sam2 = "sam2" in Path(model).stem
+         self.is_sam3 = "sam3" in Path(model).stem
          super().__init__(model=model, task="segment")
 
      def _load(self, weights: str, task=None):
@@ -72,9 +73,14 @@ class SAM(Model):
              >>> sam = SAM("sam_b.pt")
              >>> sam._load("path/to/custom_weights.pt")
          """
-         from .build import build_sam  # slow import
+         if self.is_sam3:
+             from .build_sam3 import build_interactive_sam3
 
-         self.model = build_sam(weights)
+             self.model = build_interactive_sam3(weights)
+         else:
+             from .build import build_sam  # slow import
+
+             self.model = build_sam(weights)
 
      def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
          """Perform segmentation prediction on the given image or video source.
@@ -158,4 +164,6 @@ class SAM(Model):
              >>> print(task_map)
              {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
          """
-         return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
+         return {
+             "segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor}
+         }
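
A sketch of the resulting routing, assuming a local checkpoint whose filename stem contains "sam3" (the exact weight name is an assumption):

from ultralytics import SAM

sam = SAM("sam3.pt")  # "sam3" in stem -> is_sam3=True -> _load() calls build_interactive_sam3()
results = sam.predict("image.jpg", bboxes=[100, 100, 300, 300])  # served by SAM3Predictor via task_map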
ultralytics/models/sam/modules/blocks.py
@@ -79,6 +79,7 @@ class MaskDownSampler(nn.Module):
          padding: int = 0,
          total_stride: int = 16,
          activation: type[nn.Module] = nn.GELU,
+         interpol_size: tuple[int, int] | None = None,
      ):
          """Initialize a mask downsampler module for progressive downsampling and channel expansion."""
          super().__init__()
@@ -102,9 +103,24 @@ class MaskDownSampler(nn.Module):
              mask_in_chans = mask_out_chans
 
          self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+         self.interpol_size = interpol_size
+         if self.interpol_size is not None:
+             assert isinstance(self.interpol_size, (list, tuple)), (
+                 f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
+             )
+             self.interpol_size = list(interpol_size)
+             assert len(self.interpol_size) == 2
 
      def forward(self, x: Tensor) -> Tensor:
          """Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+         if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
+             x = F.interpolate(
+                 x.float(),
+                 size=self.interpol_size,
+                 align_corners=False,
+                 mode="bilinear",
+                 antialias=True,
+             ).to(x.dtype)
          return self.encoder(x)
 
 
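
A sketch of what the new interpol_size argument does, using the [1152, 1152] value that build_interactive_sam3 passes down through MemoryEncoder:

import torch
import torch.nn.functional as F

mask = torch.rand(1, 1, 1008, 1008)  # mask at the model's 1008x1008 input resolution
# MaskDownSampler.forward() now resizes to interpol_size before the strided conv encoder runs.
resized = F.interpolate(mask, size=[1152, 1152], mode="bilinear", align_corners=False, antialias=True)
print(resized.shape)  # torch.Size([1, 1, 1152, 1152])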
@@ -429,13 +445,7 @@ class RoPEAttention(Attention):
          )
 
          # Attention
-         _, _, _, c_per_head = q.shape
-         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
-         attn = attn / math.sqrt(c_per_head)
-         attn = torch.softmax(attn, dim=-1)
-
-         # Get output
-         out = attn @ v
+         out = F.scaled_dot_product_attention(q, k, v)
 
          out = self._recombine_heads(out)
          out = self.out_proj(out)
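
The removed manual attention and the new fused call compute the same result (up to floating-point tolerance) when no mask or dropout is involved; a quick self-contained check:

import math

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 16, 32) for _ in range(3))  # B x N_heads x N_tokens x C_per_head
manual = torch.softmax((q @ k.permute(0, 1, 3, 2)) / math.sqrt(q.shape[-1]), dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(manual, fused, atol=1e-5)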
@@ -1033,6 +1043,7 @@ class PatchEmbed(nn.Module):
          padding: tuple[int, int] = (0, 0),
          in_chans: int = 3,
          embed_dim: int = 768,
+         bias: bool = True,
      ) -> None:
          """Initialize the PatchEmbed module for converting image patches to embeddings.
 
@@ -1045,10 +1056,11 @@ class PatchEmbed(nn.Module):
              padding (tuple[int, int]): Padding applied to the input before convolution.
              in_chans (int): Number of input image channels.
              embed_dim (int): Dimensionality of the output patch embeddings.
+             bias (bool): Whether to include a bias term in the convolutional layer.
          """
          super().__init__()
 
-         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
 
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Compute patch embedding by applying convolution and transposing resulting tensor."""
ultralytics/models/sam/modules/decoders.py
@@ -436,9 +436,8 @@ class SAM2MaskDecoder(nn.Module):
      def _get_stability_scores(self, mask_logits):
          """Compute mask stability scores based on IoU between upper and lower thresholds."""
          mask_logits = mask_logits.flatten(-2)
-         stability_delta = self.dynamic_multimask_stability_delta
-         area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
-         area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+         area_i = torch.sum(mask_logits > self.dynamic_multimask_stability_delta, dim=-1).float()
+         area_u = torch.sum(mask_logits > -self.dynamic_multimask_stability_delta, dim=-1).float()
          return torch.where(area_u > 0, area_i / area_u, 1.0)
 
      def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
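
A worked example of the stability score: it is the IoU between the mask thresholded at +delta and at -delta, so confidently binary logits score near 1.

import torch

delta = 0.05  # dynamic_multimask_stability_delta, as configured by build_interactive_sam3 above
mask_logits = torch.tensor([[0.20, 0.04, -0.04, -0.20]])  # already flattened
area_i = (mask_logits > delta).sum(-1).float()   # 1 pixel above +0.05
area_u = (mask_logits > -delta).sum(-1).float()  # 3 pixels above -0.05
score = torch.where(area_u > 0, area_i / area_u, torch.tensor(1.0))
print(score)  # tensor([0.3333])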
ultralytics/models/sam/modules/encoders.py
@@ -361,6 +361,7 @@ class MemoryEncoder(nn.Module):
          self,
          out_dim,
          in_dim=256,  # in_dim of pix_feats
+         interpol_size: tuple[int, int] | None = None,
      ):
          """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
 
@@ -370,10 +371,12 @@ class MemoryEncoder(nn.Module):
          Args:
              out_dim (int): Output dimension of the encoded features.
              in_dim (int): Input dimension of the pixel features.
+             interpol_size (tuple[int, int] | None): Size to interpolate masks to. If None, uses the size of pixel
+                 features.
          """
          super().__init__()
 
-         self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+         self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1, interpol_size=interpol_size)
 
          self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
          self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
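
This is the hook used by build_interactive_sam3 above; a minimal construction sketch:

from ultralytics.models.sam.modules.encoders import MemoryEncoder

# Masks are resized to 1152x1152 before being downsampled into memory features.
memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])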
ultralytics/models/sam/modules/memory_attention.py
@@ -59,6 +59,8 @@ class MemoryAttentionLayer(nn.Module):
          pos_enc_at_attn: bool = False,
          pos_enc_at_cross_attn_keys: bool = True,
          pos_enc_at_cross_attn_queries: bool = False,
+         self_attn: nn.Module | None = None,
+         cross_attn: nn.Module | None = None,
      ):
          """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
 
@@ -69,13 +71,15 @@ class MemoryAttentionLayer(nn.Module):
              pos_enc_at_attn (bool): Whether to add positional encoding at attention.
              pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
              pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
+             self_attn (nn.Module | None): Custom self-attention module. If None, a default RoPEAttention is used.
+             cross_attn (nn.Module | None): Custom cross-attention module. If None, a default RoPEAttention is used.
          """
          super().__init__()
          self.d_model = d_model
          self.dim_feedforward = dim_feedforward
          self.dropout_value = dropout
-         self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
-         self.cross_attn_image = RoPEAttention(
+         self.self_attn = self_attn or RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+         self.cross_attn_image = cross_attn or RoPEAttention(
              rope_k_repeat=True,
              embedding_dim=256,
              num_heads=1,
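
The new arguments let callers inject their own attention modules; a sketch mirroring what build_interactive_sam3 passes (RoPEAttention configured for SAM3's 72x72 feature maps), with the remaining keyword values taken from that builder:

from ultralytics.models.sam.modules.blocks import RoPEAttention
from ultralytics.models.sam.modules.memory_attention import MemoryAttentionLayer

layer = MemoryAttentionLayer(
    d_model=256,
    dim_feedforward=2048,
    dropout=0.1,
    self_attn=RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1, feat_sizes=[72, 72]),
    cross_attn=RoPEAttention(
        embedding_dim=256, num_heads=1, downsample_rate=1, kv_in_dim=64, rope_k_repeat=True, feat_sizes=[72, 72]
    ),
)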