dgenerate-ultralytics-headless 8.3.236__py3-none-any.whl → 8.3.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/RECORD +117 -105
  3. tests/test_exports.py +3 -1
  4. tests/test_python.py +2 -2
  5. tests/test_solutions.py +6 -6
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +4 -4
  8. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  9. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  10. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  11. ultralytics/cfg/datasets/VOC.yaml +15 -16
  12. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  13. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  14. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  15. ultralytics/cfg/datasets/dota8.yaml +2 -2
  16. ultralytics/cfg/datasets/kitti.yaml +1 -1
  17. ultralytics/cfg/datasets/xView.yaml +16 -16
  18. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  19. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  20. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  21. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  22. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  23. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  24. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  25. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  26. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  27. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  28. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  29. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  30. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  31. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  32. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  33. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  34. ultralytics/data/augment.py +1 -1
  35. ultralytics/data/base.py +4 -2
  36. ultralytics/data/build.py +4 -4
  37. ultralytics/data/loaders.py +17 -12
  38. ultralytics/data/utils.py +4 -4
  39. ultralytics/engine/exporter.py +40 -25
  40. ultralytics/engine/predictor.py +8 -6
  41. ultralytics/engine/results.py +12 -13
  42. ultralytics/engine/trainer.py +10 -2
  43. ultralytics/engine/tuner.py +2 -3
  44. ultralytics/engine/validator.py +2 -2
  45. ultralytics/models/fastsam/model.py +2 -2
  46. ultralytics/models/fastsam/predict.py +2 -3
  47. ultralytics/models/fastsam/val.py +4 -4
  48. ultralytics/models/rtdetr/predict.py +2 -3
  49. ultralytics/models/rtdetr/val.py +10 -5
  50. ultralytics/models/sam/__init__.py +14 -1
  51. ultralytics/models/sam/build.py +22 -13
  52. ultralytics/models/sam/build_sam3.py +377 -0
  53. ultralytics/models/sam/model.py +13 -5
  54. ultralytics/models/sam/modules/blocks.py +20 -8
  55. ultralytics/models/sam/modules/decoders.py +2 -3
  56. ultralytics/models/sam/modules/encoders.py +4 -1
  57. ultralytics/models/sam/modules/memory_attention.py +6 -2
  58. ultralytics/models/sam/modules/sam.py +159 -10
  59. ultralytics/models/sam/modules/utils.py +134 -4
  60. ultralytics/models/sam/predict.py +2073 -139
  61. ultralytics/models/sam/sam3/__init__.py +3 -0
  62. ultralytics/models/sam/sam3/decoder.py +546 -0
  63. ultralytics/models/sam/sam3/encoder.py +535 -0
  64. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  65. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  66. ultralytics/models/sam/sam3/model_misc.py +198 -0
  67. ultralytics/models/sam/sam3/necks.py +129 -0
  68. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  69. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  70. ultralytics/models/sam/sam3/vitdet.py +546 -0
  71. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  72. ultralytics/models/yolo/classify/val.py +1 -1
  73. ultralytics/models/yolo/detect/train.py +1 -1
  74. ultralytics/models/yolo/detect/val.py +7 -7
  75. ultralytics/models/yolo/obb/val.py +19 -8
  76. ultralytics/models/yolo/pose/val.py +1 -1
  77. ultralytics/models/yolo/segment/val.py +1 -1
  78. ultralytics/nn/autobackend.py +9 -9
  79. ultralytics/nn/modules/block.py +1 -1
  80. ultralytics/nn/modules/transformer.py +21 -1
  81. ultralytics/nn/tasks.py +3 -3
  82. ultralytics/nn/text_model.py +2 -7
  83. ultralytics/solutions/ai_gym.py +1 -1
  84. ultralytics/solutions/analytics.py +6 -6
  85. ultralytics/solutions/config.py +1 -1
  86. ultralytics/solutions/distance_calculation.py +1 -1
  87. ultralytics/solutions/object_counter.py +1 -1
  88. ultralytics/solutions/object_cropper.py +3 -6
  89. ultralytics/solutions/parking_management.py +21 -17
  90. ultralytics/solutions/queue_management.py +5 -5
  91. ultralytics/solutions/region_counter.py +2 -2
  92. ultralytics/solutions/security_alarm.py +1 -1
  93. ultralytics/solutions/solutions.py +45 -22
  94. ultralytics/solutions/speed_estimation.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -1
  96. ultralytics/trackers/bot_sort.py +4 -3
  97. ultralytics/trackers/byte_tracker.py +4 -4
  98. ultralytics/trackers/utils/gmc.py +6 -7
  99. ultralytics/trackers/utils/kalman_filter.py +2 -1
  100. ultralytics/trackers/utils/matching.py +4 -3
  101. ultralytics/utils/__init__.py +12 -3
  102. ultralytics/utils/benchmarks.py +2 -2
  103. ultralytics/utils/callbacks/tensorboard.py +19 -25
  104. ultralytics/utils/checks.py +4 -3
  105. ultralytics/utils/downloads.py +1 -1
  106. ultralytics/utils/export/tensorflow.py +16 -2
  107. ultralytics/utils/files.py +13 -12
  108. ultralytics/utils/logger.py +62 -27
  109. ultralytics/utils/metrics.py +1 -1
  110. ultralytics/utils/ops.py +7 -9
  111. ultralytics/utils/patches.py +3 -3
  112. ultralytics/utils/plotting.py +7 -12
  113. ultralytics/utils/tuner.py +1 -1
  114. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/WHEEL +0 -0
  115. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/entry_points.txt +0 -0
  116. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/licenses/LICENSE +0 -0
  117. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/build_sam3.py (new file)
@@ -0,0 +1,377 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ import torch.nn as nn
+
+ from ultralytics.nn.modules.transformer import MLP
+ from ultralytics.utils.patches import torch_load
+
+ from .modules.blocks import PositionEmbeddingSine, RoPEAttention
+ from .modules.encoders import MemoryEncoder
+ from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+ from .modules.sam import SAM3Model
+ from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
+ from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
+ from .sam3.geometry_encoders import SequenceGeometryEncoder
+ from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
+ from .sam3.model_misc import DotProductScoring, TransformerWrapper
+ from .sam3.necks import Sam3DualViTDetNeck
+ from .sam3.sam3_image import SAM3SemanticModel
+ from .sam3.text_encoder_ve import VETextEncoder
+ from .sam3.vitdet import ViT
+ from .sam3.vl_combiner import SAM3VLBackbone
+
+
+ def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
+     """Create SAM3 visual backbone with ViT and neck."""
+     # Position encoding
+     position_encoding = PositionEmbeddingSine(
+         num_pos_feats=256,
+         normalize=True,
+         scale=None,
+         temperature=10000,
+     )
+
+     # ViT backbone
+     vit_backbone = ViT(
+         img_size=1008,
+         pretrain_img_size=336,
+         patch_size=14,
+         embed_dim=1024,
+         depth=32,
+         num_heads=16,
+         mlp_ratio=4.625,
+         norm_layer="LayerNorm",
+         drop_path_rate=0.1,
+         qkv_bias=True,
+         use_abs_pos=True,
+         tile_abs_pos=True,
+         global_att_blocks=(7, 15, 23, 31),
+         rel_pos_blocks=(),
+         use_rope=True,
+         use_interp_rope=True,
+         window_size=24,
+         pretrain_use_cls_token=True,
+         retain_cls_token=False,
+         ln_pre=True,
+         ln_post=False,
+         return_interm_layers=False,
+         bias_patch_embed=False,
+         compile_mode=compile_mode,
+     )
+     return Sam3DualViTDetNeck(
+         position_encoding=position_encoding,
+         d_model=256,
+         scale_factors=[4.0, 2.0, 1.0, 0.5],
+         trunk=vit_backbone,
+         add_sam2_neck=enable_inst_interactivity,
+     )
+
+
+ def _create_sam3_transformer() -> TransformerWrapper:
+     """Create SAM3 detector encoder and decoder."""
+     encoder: TransformerEncoderFusion = TransformerEncoderFusion(
+         layer=TransformerEncoderLayer(
+             d_model=256,
+             dim_feedforward=2048,
+             dropout=0.1,
+             pos_enc_at_attn=True,
+             pos_enc_at_cross_attn_keys=False,
+             pos_enc_at_cross_attn_queries=False,
+             pre_norm=True,
+             self_attention=nn.MultiheadAttention(
+                 num_heads=8,
+                 dropout=0.1,
+                 embed_dim=256,
+                 batch_first=True,
+             ),
+             cross_attention=nn.MultiheadAttention(
+                 num_heads=8,
+                 dropout=0.1,
+                 embed_dim=256,
+                 batch_first=True,
+             ),
+         ),
+         num_layers=6,
+         d_model=256,
+         num_feature_levels=1,
+         frozen=False,
+         use_act_checkpoint=True,
+         add_pooled_text_to_img_feat=False,
+         pool_text_with_mask=True,
+     )
+     decoder: TransformerDecoder = TransformerDecoder(
+         layer=TransformerDecoderLayer(
+             d_model=256,
+             dim_feedforward=2048,
+             dropout=0.1,
+             cross_attention=nn.MultiheadAttention(
+                 num_heads=8,
+                 dropout=0.1,
+                 embed_dim=256,
+             ),
+             n_heads=8,
+             use_text_cross_attention=True,
+         ),
+         num_layers=6,
+         num_queries=200,
+         return_intermediate=True,
+         box_refine=True,
+         num_o2m_queries=0,
+         dac=True,
+         boxRPB="log",
+         d_model=256,
+         frozen=False,
+         interaction_layer=None,
+         dac_use_selfatt_ln=True,
+         use_act_checkpoint=True,
+         presence_token=True,
+     )
+
+     return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
+
+
+ def build_sam3_image_model(checkpoint_path: str, enable_segmentation: bool = True, compile: bool = False):
+     """Build SAM3 image model.
+
+     Args:
+         checkpoint_path: Optional path to model checkpoint
+         enable_segmentation: Whether to enable segmentation head
+         compile: To enable compilation, set to "default"
+
+     Returns:
+         A SAM3 image model
+     """
+     try:
+         import clip
+     except ImportError:
+         from ultralytics.utils.checks import check_requirements
+
+         check_requirements("git+https://github.com/ultralytics/CLIP.git")
+         import clip
+     # Create visual components
+     compile_mode = "default" if compile else None
+     vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)
+
+     # Create text components
+     text_encoder = VETextEncoder(
+         tokenizer=clip.simple_tokenizer.SimpleTokenizer(),
+         d_model=256,
+         width=1024,
+         heads=16,
+         layers=24,
+     )
+
+     # Create visual-language backbone
+     backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)
+
+     # Create transformer components
+     transformer = _create_sam3_transformer()
+
+     # Create dot product scoring
+     dot_prod_scoring = DotProductScoring(
+         d_model=256,
+         d_proj=256,
+         prompt_mlp=MLP(
+             input_dim=256,
+             hidden_dim=2048,
+             output_dim=256,
+             num_layers=2,
+             residual=True,
+             out_norm=nn.LayerNorm(256),
+         ),
+     )
+
+     # Create segmentation head if enabled
+     segmentation_head = (
+         UniversalSegmentationHead(
+             hidden_dim=256,
+             upsampling_stages=3,
+             aux_masks=False,
+             presence_head=False,
+             dot_product_scorer=None,
+             act_ckpt=True,
+             cross_attend_prompt=nn.MultiheadAttention(
+                 num_heads=8,
+                 dropout=0,
+                 embed_dim=256,
+             ),
+             pixel_decoder=PixelDecoder(
+                 num_upsampling_stages=3,
+                 interpolation_mode="nearest",
+                 hidden_dim=256,
+                 compile_mode=compile_mode,
+             ),
+         )
+         if enable_segmentation
+         else None
+     )
+
+     # Create geometry encoder
+     input_geometry_encoder = SequenceGeometryEncoder(
+         pos_enc=PositionEmbeddingSine(
+             num_pos_feats=256,
+             normalize=True,
+             scale=None,
+             temperature=10000,
+         ),
+         encode_boxes_as_points=False,
+         boxes_direct_project=True,
+         boxes_pool=True,
+         boxes_pos_enc=True,
+         d_model=256,
+         num_layers=3,
+         layer=TransformerEncoderLayer(
+             d_model=256,
+             dim_feedforward=2048,
+             dropout=0.1,
+             pos_enc_at_attn=False,
+             pre_norm=True,
+             pos_enc_at_cross_attn_queries=False,
+             pos_enc_at_cross_attn_keys=True,
+         ),
+         use_act_ckpt=True,
+         add_cls=True,
+         add_post_encode_proj=True,
+     )
+
+     # Create the SAM3SemanticModel model
+     model = SAM3SemanticModel(
+         backbone=backbone,
+         transformer=transformer,
+         input_geometry_encoder=input_geometry_encoder,
+         segmentation_head=segmentation_head,
+         num_feature_levels=1,
+         o2m_mask_predict=True,
+         dot_prod_scoring=dot_prod_scoring,
+         use_instance_query=False,
+         multimask_output=True,
+     )
+
+     # Load checkpoint
+     model = _load_checkpoint(model, checkpoint_path)
+     model.eval()
+     return model
+
+
+ def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
+     """Build the SAM3 Tracker module for video tracking.
+
+     Returns:
+         Sam3TrackerPredictor: Wrapped SAM3 Tracker module
+     """
+     # Create model components
+     memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
+     memory_attention = MemoryAttention(
+         batch_first=True,
+         d_model=256,
+         pos_enc_at_input=True,
+         layer=MemoryAttentionLayer(
+             dim_feedforward=2048,
+             dropout=0.1,
+             pos_enc_at_attn=False,
+             pos_enc_at_cross_attn_keys=True,
+             pos_enc_at_cross_attn_queries=False,
+             self_attn=RoPEAttention(
+                 embedding_dim=256,
+                 num_heads=1,
+                 downsample_rate=1,
+                 rope_theta=10000.0,
+                 feat_sizes=[72, 72],
+             ),
+             d_model=256,
+             cross_attn=RoPEAttention(
+                 embedding_dim=256,
+                 num_heads=1,
+                 downsample_rate=1,
+                 kv_in_dim=64,
+                 rope_theta=10000.0,
+                 feat_sizes=[72, 72],
+                 rope_k_repeat=True,
+             ),
+         ),
+         num_layers=4,
+     )
+
+     backbone = (
+         SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)
+         if with_backbone
+         else None
+     )
+     model = SAM3Model(
+         image_size=1008,
+         image_encoder=backbone,
+         memory_attention=memory_attention,
+         memory_encoder=memory_encoder,
+         backbone_stride=14,
+         num_maskmem=7,
+         sigmoid_scale_for_mem_enc=20.0,
+         sigmoid_bias_for_mem_enc=-10.0,
+         use_mask_input_as_output_without_sam=True,
+         directly_add_no_mem_embed=True,
+         use_high_res_features_in_sam=True,
+         multimask_output_in_sam=True,
+         iou_prediction_use_sigmoid=True,
+         use_obj_ptrs_in_encoder=True,
+         add_tpos_enc_to_obj_ptrs=True,
+         only_obj_ptrs_in_the_past_for_eval=True,
+         pred_obj_scores=True,
+         pred_obj_scores_mlp=True,
+         fixed_no_obj_ptr=True,
+         multimask_output_for_tracking=True,
+         use_multimask_token_for_obj_ptr=True,
+         multimask_min_pt_num=0,
+         multimask_max_pt_num=1,
+         use_mlp_for_obj_ptr_proj=True,
+         compile_image_encoder=False,
+         no_obj_embed_spatial=True,
+         proj_tpos_enc_in_obj_ptrs=True,
+         use_signed_tpos_enc_to_obj_ptrs=True,
+         sam_mask_decoder_extra_args=dict(
+             dynamic_multimask_via_stability=True,
+             dynamic_multimask_stability_delta=0.05,
+             dynamic_multimask_stability_thresh=0.98,
+         ),
+     )
+
+     # Load checkpoint if provided
+     model = _load_checkpoint(model, checkpoint_path, interactive=True)
+
+     # Setup device and mode
+     model.eval()
+     return model
+
+
+ def _load_checkpoint(model, checkpoint, interactive=False):
+     """Load SAM3 model checkpoint from file."""
+     with open(checkpoint, "rb") as f:
+         ckpt = torch_load(f)
+     if "model" in ckpt and isinstance(ckpt["model"], dict):
+         ckpt = ckpt["model"]
+     sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
+     if interactive:
+         sam3_image_ckpt.update(
+             {
+                 k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v
+                 for k, v in sam3_image_ckpt.items()
+                 if "backbone.vision_backbone" in k
+             }
+         )
+         sam3_image_ckpt.update(
+             {
+                 k.replace("tracker.transformer.encoder", "memory_attention"): v
+                 for k, v in ckpt.items()
+                 if "tracker.transformer" in k
+             }
+         )
+         sam3_image_ckpt.update(
+             {
+                 k.replace("tracker.maskmem_backbone", "memory_encoder"): v
+                 for k, v in ckpt.items()
+                 if "tracker.maskmem_backbone" in k
+             }
+         )
+         sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." in k})
+     model.load_state_dict(sam3_image_ckpt, strict=False)
+     return model
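The two builders above, build_sam3_image_model and build_interactive_sam3, are the entry points this release adds for SAM3. A minimal usage sketch follows; the checkpoint filename is illustrative and the weights are not shipped with this wheel.

    from ultralytics.models.sam.build_sam3 import build_interactive_sam3, build_sam3_image_model

    # Promptable image model with the segmentation head enabled (assumed local checkpoint path)
    image_model = build_sam3_image_model("sam3.pt", enable_segmentation=True)

    # Memory-augmented tracker variant used for video / interactive prediction
    tracker = build_interactive_sam3("sam3.pt")

Both builders load weights through _load_checkpoint and return the module in eval mode.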
ultralytics/models/sam/model.py
@@ -21,7 +21,7 @@ from pathlib import Path
  from ultralytics.engine.model import Model
  from ultralytics.utils.torch_utils import model_info

- from .predict import Predictor, SAM2Predictor
+ from .predict import Predictor, SAM2Predictor, SAM3Predictor


  class SAM(Model):
@@ -44,7 +44,7 @@ class SAM(Model):
          >>> sam = SAM("sam_b.pt")
          >>> results = sam.predict("image.jpg", points=[[500, 375]])
          >>> for r in results:
-         >>>     print(f"Detected {len(r.masks)} masks")
+         ...     print(f"Detected {len(r.masks)} masks")
      """

      def __init__(self, model: str = "sam_b.pt") -> None:
@@ -59,6 +59,7 @@ class SAM(Model):
          if model and Path(model).suffix not in {".pt", ".pth"}:
              raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
          self.is_sam2 = "sam2" in Path(model).stem
+         self.is_sam3 = "sam3" in Path(model).stem
          super().__init__(model=model, task="segment")

      def _load(self, weights: str, task=None):
@@ -72,9 +73,14 @@ class SAM(Model):
              >>> sam = SAM("sam_b.pt")
              >>> sam._load("path/to/custom_weights.pt")
          """
-         from .build import build_sam  # slow import
+         if self.is_sam3:
+             from .build_sam3 import build_interactive_sam3

-         self.model = build_sam(weights)
+             self.model = build_interactive_sam3(weights)
+         else:
+             from .build import build_sam  # slow import
+
+             self.model = build_sam(weights)

      def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
          """Perform segmentation prediction on the given image or video source.
@@ -158,4 +164,6 @@ class SAM(Model):
          >>> print(task_map)
          {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
          """
-         return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
+         return {
+             "segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor}
+         }
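With these changes the SAM wrapper dispatches on the weight filename: any checkpoint whose stem contains "sam3" is built via build_interactive_sam3 and served by SAM3Predictor, while SAM 2 and original SAM weights keep their existing paths. A short sketch of the new route (the "sam3.pt" filename is illustrative, not a file bundled with this package):

    from ultralytics import SAM

    sam = SAM("sam3.pt")  # "sam3" in the stem selects build_interactive_sam3 + SAM3Predictor
    results = sam.predict("image.jpg", points=[[500, 375]])
    for r in results:
        print(f"Detected {len(r.masks)} masks")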
ultralytics/models/sam/modules/blocks.py
@@ -79,6 +79,7 @@ class MaskDownSampler(nn.Module):
          padding: int = 0,
          total_stride: int = 16,
          activation: type[nn.Module] = nn.GELU,
+         interpol_size: tuple[int, int] | None = None,
      ):
          """Initialize a mask downsampler module for progressive downsampling and channel expansion."""
          super().__init__()
@@ -102,9 +103,24 @@ class MaskDownSampler(nn.Module):
              mask_in_chans = mask_out_chans

          self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+         self.interpol_size = interpol_size
+         if self.interpol_size is not None:
+             assert isinstance(self.interpol_size, (list, tuple)), (
+                 f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
+             )
+             self.interpol_size = list(interpol_size)
+             assert len(self.interpol_size) == 2

      def forward(self, x: Tensor) -> Tensor:
          """Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+         if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
+             x = F.interpolate(
+                 x.float(),
+                 size=self.interpol_size,
+                 align_corners=False,
+                 mode="bilinear",
+                 antialias=True,
+             ).to(x.dtype)
          return self.encoder(x)


@@ -429,13 +445,7 @@ class RoPEAttention(Attention):
          )

          # Attention
-         _, _, _, c_per_head = q.shape
-         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
-         attn = attn / math.sqrt(c_per_head)
-         attn = torch.softmax(attn, dim=-1)
-
-         # Get output
-         out = attn @ v
+         out = F.scaled_dot_product_attention(q, k, v)

          out = self._recombine_heads(out)
          out = self.out_proj(out)
@@ -1033,6 +1043,7 @@ class PatchEmbed(nn.Module):
          padding: tuple[int, int] = (0, 0),
          in_chans: int = 3,
          embed_dim: int = 768,
+         bias: bool = True,
      ) -> None:
          """Initialize the PatchEmbed module for converting image patches to embeddings.

@@ -1045,10 +1056,11 @@ class PatchEmbed(nn.Module):
              padding (tuple[int, int]): Padding applied to the input before convolution.
              in_chans (int): Number of input image channels.
              embed_dim (int): Dimensionality of the output patch embeddings.
+             bias (bool): Whether to include a bias term in the convolutional layer.
          """
          super().__init__()

-         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Compute patch embedding by applying convolution and transposing resulting tensor."""
ultralytics/models/sam/modules/decoders.py
@@ -436,9 +436,8 @@ class SAM2MaskDecoder(nn.Module):
      def _get_stability_scores(self, mask_logits):
          """Compute mask stability scores based on IoU between upper and lower thresholds."""
          mask_logits = mask_logits.flatten(-2)
-         stability_delta = self.dynamic_multimask_stability_delta
-         area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
-         area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+         area_i = torch.sum(mask_logits > self.dynamic_multimask_stability_delta, dim=-1).float()
+         area_u = torch.sum(mask_logits > -self.dynamic_multimask_stability_delta, dim=-1).float()
          return torch.where(area_u > 0, area_i / area_u, 1.0)

      def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
ultralytics/models/sam/modules/encoders.py
@@ -361,6 +361,7 @@ class MemoryEncoder(nn.Module):
          self,
          out_dim,
          in_dim=256,  # in_dim of pix_feats
+         interpol_size: tuple[int, int] | None = None,
      ):
          """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.

@@ -370,10 +371,12 @@ class MemoryEncoder(nn.Module):
          Args:
              out_dim (int): Output dimension of the encoded features.
              in_dim (int): Input dimension of the pixel features.
+             interpol_size (tuple[int, int] | None): Size to interpolate masks to. If None, uses the size of pixel
+                 features.
          """
          super().__init__()

-         self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+         self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1, interpol_size=interpol_size)

          self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
          self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
ultralytics/models/sam/modules/memory_attention.py
@@ -59,6 +59,8 @@ class MemoryAttentionLayer(nn.Module):
          pos_enc_at_attn: bool = False,
          pos_enc_at_cross_attn_keys: bool = True,
          pos_enc_at_cross_attn_queries: bool = False,
+         self_attn: nn.Module | None = None,
+         cross_attn: nn.Module | None = None,
      ):
          """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.

@@ -69,13 +71,15 @@ class MemoryAttentionLayer(nn.Module):
              pos_enc_at_attn (bool): Whether to add positional encoding at attention.
              pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
              pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
+             self_attn (nn.Module | None): Custom self-attention module. If None, a default RoPEAttention is used.
+             cross_attn (nn.Module | None): Custom cross-attention module. If None, a default RoPEAttention is used.
          """
          super().__init__()
          self.d_model = d_model
          self.dim_feedforward = dim_feedforward
          self.dropout_value = dropout
-         self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
-         self.cross_attn_image = RoPEAttention(
+         self.self_attn = self_attn or RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+         self.cross_attn_image = cross_attn or RoPEAttention(
              rope_k_repeat=True,
              embedding_dim=256,
              num_heads=1,