dgenerate-ultralytics-headless 8.3.236__py3-none-any.whl → 8.3.237__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/RECORD +38 -25
- ultralytics/__init__.py +1 -1
- ultralytics/engine/exporter.py +17 -10
- ultralytics/engine/predictor.py +3 -2
- ultralytics/engine/trainer.py +8 -0
- ultralytics/models/rtdetr/val.py +5 -1
- ultralytics/models/sam/__init__.py +14 -1
- ultralytics/models/sam/build.py +17 -8
- ultralytics/models/sam/build_sam3.py +374 -0
- ultralytics/models/sam/model.py +12 -4
- ultralytics/models/sam/modules/blocks.py +20 -8
- ultralytics/models/sam/modules/decoders.py +2 -3
- ultralytics/models/sam/modules/encoders.py +4 -1
- ultralytics/models/sam/modules/memory_attention.py +6 -2
- ultralytics/models/sam/modules/sam.py +150 -6
- ultralytics/models/sam/modules/utils.py +134 -4
- ultralytics/models/sam/predict.py +2076 -118
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +535 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +198 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +357 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/tokenizer_ve.py +242 -0
- ultralytics/models/sam/sam3/vitdet.py +546 -0
- ultralytics/models/sam/sam3/vl_combiner.py +165 -0
- ultralytics/models/yolo/obb/val.py +18 -7
- ultralytics/nn/modules/transformer.py +21 -1
- ultralytics/utils/checks.py +2 -2
- ultralytics/utils/ops.py +1 -3
- {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/build_sam3.py
ADDED
@@ -0,0 +1,374 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+import torch.nn as nn
+
+from ultralytics.nn.modules.transformer import MLP
+from ultralytics.utils.patches import torch_load
+
+from .modules.blocks import PositionEmbeddingSine, RoPEAttention
+from .modules.encoders import MemoryEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM3Model
+from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
+from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
+from .sam3.geometry_encoders import SequenceGeometryEncoder
+from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
+from .sam3.model_misc import DotProductScoring, TransformerWrapper
+from .sam3.necks import Sam3DualViTDetNeck
+from .sam3.sam3_image import SAM3SemanticModel
+from .sam3.text_encoder_ve import VETextEncoder
+from .sam3.tokenizer_ve import SimpleTokenizer
+from .sam3.vitdet import ViT
+from .sam3.vl_combiner import SAM3VLBackbone
+
+
+def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
+    """Create SAM3 visual backbone with ViT and neck."""
+    # Position encoding
+    position_encoding = PositionEmbeddingSine(
+        num_pos_feats=256,
+        normalize=True,
+        scale=None,
+        temperature=10000,
+    )
+
+    # ViT backbone
+    vit_backbone = ViT(
+        img_size=1008,
+        pretrain_img_size=336,
+        patch_size=14,
+        embed_dim=1024,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4.625,
+        norm_layer="LayerNorm",
+        drop_path_rate=0.1,
+        qkv_bias=True,
+        use_abs_pos=True,
+        tile_abs_pos=True,
+        global_att_blocks=(7, 15, 23, 31),
+        rel_pos_blocks=(),
+        use_rope=True,
+        use_interp_rope=True,
+        window_size=24,
+        pretrain_use_cls_token=True,
+        retain_cls_token=False,
+        ln_pre=True,
+        ln_post=False,
+        return_interm_layers=False,
+        bias_patch_embed=False,
+        compile_mode=compile_mode,
+    )
+    return Sam3DualViTDetNeck(
+        position_encoding=position_encoding,
+        d_model=256,
+        scale_factors=[4.0, 2.0, 1.0, 0.5],
+        trunk=vit_backbone,
+        add_sam2_neck=enable_inst_interactivity,
+    )
+
+
+def _create_sam3_transformer() -> TransformerWrapper:
+    """Create SAM3 detector encoder and decoder."""
+    encoder: TransformerEncoderFusion = TransformerEncoderFusion(
+        layer=TransformerEncoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=True,
+            pos_enc_at_cross_attn_keys=False,
+            pos_enc_at_cross_attn_queries=False,
+            pre_norm=True,
+            self_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+                batch_first=True,
+            ),
+            cross_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+                batch_first=True,
+            ),
+        ),
+        num_layers=6,
+        d_model=256,
+        num_feature_levels=1,
+        frozen=False,
+        use_act_checkpoint=True,
+        add_pooled_text_to_img_feat=False,
+        pool_text_with_mask=True,
+    )
+    decoder: TransformerDecoder = TransformerDecoder(
+        layer=TransformerDecoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            cross_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+            ),
+            n_heads=8,
+            use_text_cross_attention=True,
+        ),
+        num_layers=6,
+        num_queries=200,
+        return_intermediate=True,
+        box_refine=True,
+        num_o2m_queries=0,
+        dac=True,
+        boxRPB="log",
+        d_model=256,
+        frozen=False,
+        interaction_layer=None,
+        dac_use_selfatt_ln=True,
+        use_act_checkpoint=True,
+        presence_token=True,
+    )
+
+    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
+
+
+def build_sam3_image_model(
+    checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False
+):
+    """Build SAM3 image model.
+
+    Args:
+        checkpoint_path: Optional path to model checkpoint
+        bpe_path: Path to the BPE tokenizer vocabulary
+        enable_segmentation: Whether to enable segmentation head
+        compile: To enable compilation, set to "default"
+
+    Returns:
+        A SAM3 image model
+    """
+    # Create visual components
+    compile_mode = "default" if compile else None
+    vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)
+
+    # Create text components
+    text_encoder = VETextEncoder(
+        tokenizer=SimpleTokenizer(bpe_path=bpe_path),
+        d_model=256,
+        width=1024,
+        heads=16,
+        layers=24,
+    )
+
+    # Create visual-language backbone
+    backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)
+
+    # Create transformer components
+    transformer = _create_sam3_transformer()
+
+    # Create dot product scoring
+    dot_prod_scoring = DotProductScoring(
+        d_model=256,
+        d_proj=256,
+        prompt_mlp=MLP(
+            input_dim=256,
+            hidden_dim=2048,
+            output_dim=256,
+            num_layers=2,
+            residual=True,
+            out_norm=nn.LayerNorm(256),
+        ),
+    )
+
+    # Create segmentation head if enabled
+    segmentation_head = (
+        UniversalSegmentationHead(
+            hidden_dim=256,
+            upsampling_stages=3,
+            aux_masks=False,
+            presence_head=False,
+            dot_product_scorer=None,
+            act_ckpt=True,
+            cross_attend_prompt=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0,
+                embed_dim=256,
+            ),
+            pixel_decoder=PixelDecoder(
+                num_upsampling_stages=3,
+                interpolation_mode="nearest",
+                hidden_dim=256,
+                compile_mode=compile_mode,
+            ),
+        )
+        if enable_segmentation
+        else None
+    )
+
+    # Create geometry encoder
+    input_geometry_encoder = SequenceGeometryEncoder(
+        pos_enc=PositionEmbeddingSine(
+            num_pos_feats=256,
+            normalize=True,
+            scale=None,
+            temperature=10000,
+        ),
+        encode_boxes_as_points=False,
+        boxes_direct_project=True,
+        boxes_pool=True,
+        boxes_pos_enc=True,
+        d_model=256,
+        num_layers=3,
+        layer=TransformerEncoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=False,
+            pre_norm=True,
+            pos_enc_at_cross_attn_queries=False,
+            pos_enc_at_cross_attn_keys=True,
+        ),
+        use_act_ckpt=True,
+        add_cls=True,
+        add_post_encode_proj=True,
+    )
+
+    # Create the SAM3SemanticModel model
+    model = SAM3SemanticModel(
+        backbone=backbone,
+        transformer=transformer,
+        input_geometry_encoder=input_geometry_encoder,
+        segmentation_head=segmentation_head,
+        num_feature_levels=1,
+        o2m_mask_predict=True,
+        dot_prod_scoring=dot_prod_scoring,
+        use_instance_query=False,
+        multimask_output=True,
+    )
+
+    # Load checkpoint
+    model = _load_checkpoint(model, checkpoint_path)
+    model.eval()
+    return model
+
+
+def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
+    """Build the SAM3 Tracker module for video tracking.
+
+    Returns:
+        Sam3TrackerPredictor: Wrapped SAM3 Tracker module
+    """
+    # Create model components
+    memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
+    memory_attention = MemoryAttention(
+        batch_first=True,
+        d_model=256,
+        pos_enc_at_input=True,
+        layer=MemoryAttentionLayer(
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=False,
+            pos_enc_at_cross_attn_keys=True,
+            pos_enc_at_cross_attn_queries=False,
+            self_attn=RoPEAttention(
+                embedding_dim=256,
+                num_heads=1,
+                downsample_rate=1,
+                rope_theta=10000.0,
+                feat_sizes=[72, 72],
+            ),
+            d_model=256,
+            cross_attn=RoPEAttention(
+                embedding_dim=256,
+                num_heads=1,
+                downsample_rate=1,
+                kv_in_dim=64,
+                rope_theta=10000.0,
+                feat_sizes=[72, 72],
+                rope_k_repeat=True,
+            ),
+        ),
+        num_layers=4,
+    )
+
+    backbone = (
+        SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)
+        if with_backbone
+        else None
+    )
+    model = SAM3Model(
+        image_size=1008,
+        image_encoder=backbone,
+        memory_attention=memory_attention,
+        memory_encoder=memory_encoder,
+        backbone_stride=14,
+        num_maskmem=7,
+        sigmoid_scale_for_mem_enc=20.0,
+        sigmoid_bias_for_mem_enc=-10.0,
+        use_mask_input_as_output_without_sam=True,
+        directly_add_no_mem_embed=True,
+        use_high_res_features_in_sam=True,
+        multimask_output_in_sam=True,
+        iou_prediction_use_sigmoid=True,
+        use_obj_ptrs_in_encoder=True,
+        add_tpos_enc_to_obj_ptrs=True,
+        only_obj_ptrs_in_the_past_for_eval=True,
+        pred_obj_scores=True,
+        pred_obj_scores_mlp=True,
+        fixed_no_obj_ptr=True,
+        multimask_output_for_tracking=True,
+        use_multimask_token_for_obj_ptr=True,
+        multimask_min_pt_num=0,
+        multimask_max_pt_num=1,
+        use_mlp_for_obj_ptr_proj=True,
+        compile_image_encoder=False,
+        no_obj_embed_spatial=True,
+        proj_tpos_enc_in_obj_ptrs=True,
+        use_signed_tpos_enc_to_obj_ptrs=True,
+        sam_mask_decoder_extra_args=dict(
+            dynamic_multimask_via_stability=True,
+            dynamic_multimask_stability_delta=0.05,
+            dynamic_multimask_stability_thresh=0.98,
+        ),
+    )
+
+    # Load checkpoint if provided
+    model = _load_checkpoint(model, checkpoint_path, interactive=True)
+
+    # Setup device and mode
+    model.eval()
+    return model
+
+
+def _load_checkpoint(model, checkpoint, interactive=False):
+    """Load SAM3 model checkpoint from file."""
+    with open(checkpoint, "rb") as f:
+        ckpt = torch_load(f)
+    if "model" in ckpt and isinstance(ckpt["model"], dict):
+        ckpt = ckpt["model"]
+    sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
+    if interactive:
+        sam3_image_ckpt.update(
+            {
+                k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v
+                for k, v in sam3_image_ckpt.items()
+                if "backbone.vision_backbone" in k
+            }
+        )
+        sam3_image_ckpt.update(
+            {
+                k.replace("tracker.transformer.encoder", "memory_attention"): v
+                for k, v in ckpt.items()
+                if "tracker.transformer" in k
+            }
+        )
+        sam3_image_ckpt.update(
+            {
+                k.replace("tracker.maskmem_backbone", "memory_encoder"): v
+                for k, v in ckpt.items()
+                if "tracker.maskmem_backbone" in k
+            }
+        )
+        sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." in k})
+    model.load_state_dict(sam3_image_ckpt, strict=False)
+    return model
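For orientation, a minimal usage sketch of the two public builders added above. This is an assumption-laden illustration, not documented API usage: the checkpoint and BPE vocabulary file names are placeholders, not assets shipped with the wheel.

```python
# Hypothetical paths; substitute your own SAM3 checkpoint and BPE vocabulary file.
from ultralytics.models.sam.build_sam3 import build_interactive_sam3, build_sam3_image_model

# Promptable tracker for interactive image/video segmentation (no text encoder attached).
tracker = build_interactive_sam3("sam3.pt")

# Text-promptable image model; requires the vocabulary consumed by SimpleTokenizer.
semantic = build_sam3_image_model("sam3.pt", bpe_path="bpe_vocab.txt.gz", enable_segmentation=True)
```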
ultralytics/models/sam/model.py
CHANGED
@@ -21,7 +21,7 @@ from pathlib import Path
 from ultralytics.engine.model import Model
 from ultralytics.utils.torch_utils import model_info
 
-from .predict import Predictor, SAM2Predictor
+from .predict import Predictor, SAM2Predictor, SAM3Predictor
 
 
 class SAM(Model):
@@ -59,6 +59,7 @@ class SAM(Model):
         if model and Path(model).suffix not in {".pt", ".pth"}:
             raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
         self.is_sam2 = "sam2" in Path(model).stem
+        self.is_sam3 = "sam3" in Path(model).stem
         super().__init__(model=model, task="segment")
 
     def _load(self, weights: str, task=None):
@@ -72,9 +73,14 @@ class SAM(Model):
             >>> sam = SAM("sam_b.pt")
             >>> sam._load("path/to/custom_weights.pt")
         """
-        from .build import build_sam  # slow import
+        if self.is_sam3:
+            from .build_sam3 import build_interactive_sam3
 
-        self.model = build_sam(weights)
+            self.model = build_interactive_sam3(weights)
+        else:
+            from .build import build_sam  # slow import
+
+            self.model = build_sam(weights)
 
     def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
         """Perform segmentation prediction on the given image or video source.
@@ -158,4 +164,6 @@ class SAM(Model):
             >>> print(task_map)
             {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
         """
-        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
+        return {
+            "segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor}
+        }
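In practice, the change above routes any checkpoint whose filename stem contains "sam3" to build_interactive_sam3 and SAM3Predictor. A hedged sketch through the existing SAM wrapper, with placeholder file names:

```python
from ultralytics import SAM

model = SAM("sam3.pt")  # "sam3" in the stem sets is_sam3, selecting build_interactive_sam3 + SAM3Predictor
results = model.predict("image.jpg", bboxes=[100, 100, 300, 300])  # same prompt interface as SAM/SAM2
```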
ultralytics/models/sam/modules/blocks.py
CHANGED
@@ -79,6 +79,7 @@ class MaskDownSampler(nn.Module):
         padding: int = 0,
         total_stride: int = 16,
         activation: type[nn.Module] = nn.GELU,
+        interpol_size: tuple[int, int] | None = None,
     ):
         """Initialize a mask downsampler module for progressive downsampling and channel expansion."""
         super().__init__()
@@ -102,9 +103,24 @@ class MaskDownSampler(nn.Module):
             mask_in_chans = mask_out_chans
 
         self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+        self.interpol_size = interpol_size
+        if self.interpol_size is not None:
+            assert isinstance(self.interpol_size, (list, tuple)), (
+                f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
+            )
+            self.interpol_size = list(interpol_size)
+            assert len(self.interpol_size) == 2
 
     def forward(self, x: Tensor) -> Tensor:
         """Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+        if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
+            x = F.interpolate(
+                x.float(),
+                size=self.interpol_size,
+                align_corners=False,
+                mode="bilinear",
+                antialias=True,
+            ).to(x.dtype)
         return self.encoder(x)
 
 
@@ -429,13 +445,7 @@ class RoPEAttention(Attention):
         )
 
         # Attention
-        _, _, _, c_per_head = q.shape
-        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
-        attn = attn / math.sqrt(c_per_head)
-        attn = torch.softmax(attn, dim=-1)
-
-        # Get output
-        out = attn @ v
+        out = F.scaled_dot_product_attention(q, k, v)
 
         out = self._recombine_heads(out)
         out = self.out_proj(out)
@@ -1033,6 +1043,7 @@ class PatchEmbed(nn.Module):
         padding: tuple[int, int] = (0, 0),
         in_chans: int = 3,
         embed_dim: int = 768,
+        bias: bool = True,
    ) -> None:
        """Initialize the PatchEmbed module for converting image patches to embeddings.
 
@@ -1045,10 +1056,11 @@ class PatchEmbed(nn.Module):
            padding (tuple[int, int]): Padding applied to the input before convolution.
            in_chans (int): Number of input image channels.
            embed_dim (int): Dimensionality of the output patch embeddings.
+            bias (bool): Whether to include a bias term in the convolutional layer.
        """
        super().__init__()
 
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute patch embedding by applying convolution and transposing resulting tensor."""
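The new interpol_size handling simply resizes mask inputs before the convolutional encoder runs. A standalone sketch of that pre-resize step (mirroring the added forward() logic, not the library class itself), using the [1152, 1152] value that build_sam3.py passes through MemoryEncoder:

```python
import torch
import torch.nn.functional as F

interpol_size = [1152, 1152]  # value passed to MemoryEncoder/MaskDownSampler in build_sam3.py
mask = torch.rand(1, 1, 1008, 1008)
if interpol_size != list(mask.shape[-2:]):
    # Bilinear, antialiased resize before the mask is downsampled and encoded.
    mask = F.interpolate(mask.float(), size=interpol_size, align_corners=False, mode="bilinear", antialias=True)
print(mask.shape)  # torch.Size([1, 1, 1152, 1152])
```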
ultralytics/models/sam/modules/decoders.py
CHANGED
@@ -436,9 +436,8 @@ class SAM2MaskDecoder(nn.Module):
     def _get_stability_scores(self, mask_logits):
         """Compute mask stability scores based on IoU between upper and lower thresholds."""
         mask_logits = mask_logits.flatten(-2)
-        stability_delta = self.dynamic_multimask_stability_delta
-        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
-        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+        area_i = torch.sum(mask_logits > self.dynamic_multimask_stability_delta, dim=-1).float()
+        area_u = torch.sum(mask_logits > -self.dynamic_multimask_stability_delta, dim=-1).float()
         return torch.where(area_u > 0, area_i / area_u, 1.0)
 
     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
ultralytics/models/sam/modules/encoders.py
CHANGED
@@ -361,6 +361,7 @@ class MemoryEncoder(nn.Module):
         self,
         out_dim,
         in_dim=256,  # in_dim of pix_feats
+        interpol_size: tuple[int, int] | None = None,
     ):
         """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
 
@@ -370,10 +371,12 @@ class MemoryEncoder(nn.Module):
         Args:
             out_dim (int): Output dimension of the encoded features.
             in_dim (int): Input dimension of the pixel features.
+            interpol_size (tuple[int, int] | None): Size to interpolate masks to. If None, uses the size of pixel
+                features.
         """
         super().__init__()
 
-        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1, interpol_size=interpol_size)
 
         self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
         self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
ultralytics/models/sam/modules/memory_attention.py
CHANGED
@@ -59,6 +59,8 @@ class MemoryAttentionLayer(nn.Module):
         pos_enc_at_attn: bool = False,
         pos_enc_at_cross_attn_keys: bool = True,
         pos_enc_at_cross_attn_queries: bool = False,
+        self_attn: nn.Module | None = None,
+        cross_attn: nn.Module | None = None,
     ):
         """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
 
@@ -69,13 +71,15 @@ class MemoryAttentionLayer(nn.Module):
             pos_enc_at_attn (bool): Whether to add positional encoding at attention.
             pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
             pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
+            self_attn (nn.Module | None): Custom self-attention module. If None, a default RoPEAttention is used.
+            cross_attn (nn.Module | None): Custom cross-attention module. If None, a default RoPEAttention is used.
         """
         super().__init__()
         self.d_model = d_model
         self.dim_feedforward = dim_feedforward
         self.dropout_value = dropout
-        self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
-        self.cross_attn_image = RoPEAttention(
+        self.self_attn = self_attn or RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+        self.cross_attn_image = cross_attn or RoPEAttention(
             rope_k_repeat=True,
             embedding_dim=256,
             num_heads=1,