dgenerate-ultralytics-headless 8.3.236-py3-none-any.whl → 8.3.237-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/RECORD +38 -25
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/engine/exporter.py +17 -10
  5. ultralytics/engine/predictor.py +3 -2
  6. ultralytics/engine/trainer.py +8 -0
  7. ultralytics/models/rtdetr/val.py +5 -1
  8. ultralytics/models/sam/__init__.py +14 -1
  9. ultralytics/models/sam/build.py +17 -8
  10. ultralytics/models/sam/build_sam3.py +374 -0
  11. ultralytics/models/sam/model.py +12 -4
  12. ultralytics/models/sam/modules/blocks.py +20 -8
  13. ultralytics/models/sam/modules/decoders.py +2 -3
  14. ultralytics/models/sam/modules/encoders.py +4 -1
  15. ultralytics/models/sam/modules/memory_attention.py +6 -2
  16. ultralytics/models/sam/modules/sam.py +150 -6
  17. ultralytics/models/sam/modules/utils.py +134 -4
  18. ultralytics/models/sam/predict.py +2076 -118
  19. ultralytics/models/sam/sam3/__init__.py +3 -0
  20. ultralytics/models/sam/sam3/decoder.py +546 -0
  21. ultralytics/models/sam/sam3/encoder.py +535 -0
  22. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  23. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  24. ultralytics/models/sam/sam3/model_misc.py +198 -0
  25. ultralytics/models/sam/sam3/necks.py +129 -0
  26. ultralytics/models/sam/sam3/sam3_image.py +357 -0
  27. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  28. ultralytics/models/sam/sam3/tokenizer_ve.py +242 -0
  29. ultralytics/models/sam/sam3/vitdet.py +546 -0
  30. ultralytics/models/sam/sam3/vl_combiner.py +165 -0
  31. ultralytics/models/yolo/obb/val.py +18 -7
  32. ultralytics/nn/modules/transformer.py +21 -1
  33. ultralytics/utils/checks.py +2 -2
  34. ultralytics/utils/ops.py +1 -3
  35. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/WHEEL +0 -0
  36. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/entry_points.txt +0 -0
  37. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/licenses/LICENSE +0 -0
  38. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/sam3/model_misc.py (new file)
@@ -0,0 +1,198 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""Various utility models."""
+
+from __future__ import annotations
+
+import math
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+
+class DotProductScoring(torch.nn.Module):
+    """A module that computes dot-product scores between a set of query features and a pooled prompt embedding."""
+
+    def __init__(
+        self,
+        d_model,
+        d_proj,
+        prompt_mlp=None,
+        clamp_logits=True,
+        clamp_max_val=12.0,
+    ):
+        """Initialize the DotProductScoring module."""
+        super().__init__()
+        self.d_proj = d_proj
+        assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
+        self.prompt_mlp = prompt_mlp  # an optional MLP projection for prompt
+        self.prompt_proj = torch.nn.Linear(d_model, d_proj)
+        self.hs_proj = torch.nn.Linear(d_model, d_proj)
+        self.scale = float(1.0 / np.sqrt(d_proj))
+        self.clamp_logits = clamp_logits
+        if self.clamp_logits:
+            self.clamp_max_val = clamp_max_val
+
+    def mean_pool_text(self, prompt, prompt_mask):
+        """Mean-pool the prompt embeddings over the valid tokens only."""
+        # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
+        is_valid = (~prompt_mask).to(prompt.dtype).permute(1, 0)[..., None]
+        # num_valid has shape (bs, 1)
+        num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
+        # mean pool over all the valid tokens -- pooled_prompt has shape (bs, proj_dim)
+        pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
+        return pooled_prompt
+
+    def forward(self, hs, prompt, prompt_mask):
+        """Compute dot-product scores between hs and prompt."""
+        # hs has shape (num_layer, bs, num_query, d_model)
+        # prompt has shape (seq, bs, d_model)
+        # prompt_mask has shape (bs, seq), where True marks padding and False marks valid tokens
+        assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2
+
+        # apply MLP on prompt if specified
+        if self.prompt_mlp is not None:
+            prompt = self.prompt_mlp(prompt.to(hs.dtype))
+
+        # first, get the mean-pooled version of the prompt
+        pooled_prompt = self.mean_pool_text(prompt, prompt_mask)
+
+        # then, project pooled_prompt and hs to d_proj dimensions
+        proj_pooled_prompt = self.prompt_proj(pooled_prompt)  # (bs, d_proj)
+        proj_hs = self.hs_proj(hs)  # (num_layer, bs, num_query, d_proj)
+
+        # finally, get dot-product scores of shape (num_layer, bs, num_query, 1)
+        scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
+        scores *= self.scale
+
+        # clamp scores to a max value to avoid numerical issues in loss or matcher
+        if self.clamp_logits:
+            scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)
+
+        return scores
+
+
+class LayerScale(nn.Module):
+    """LayerScale module that scales the input by a learnable per-channel parameter (gamma)."""
+
+    def __init__(
+        self,
+        dim: int,
+        init_values: float | Tensor = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        """Initialize the LayerScale module."""
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Apply LayerScale to the input tensor."""
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class TransformerWrapper(nn.Module):
+    """A wrapper for the transformer consisting of an encoder and a decoder."""
+
+    def __init__(
+        self,
+        encoder,
+        decoder,
+        d_model: int,
+        two_stage_type="none",  # ["none"] only for now
+        pos_enc_at_input_dec=True,
+    ):
+        """Initialize the TransformerWrapper."""
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.num_queries = decoder.num_queries if decoder is not None else None
+        self.pos_enc_at_input_dec = pos_enc_at_input_dec
+
+        # for two stage
+        assert two_stage_type in ["none"], f"unknown param {two_stage_type} of two_stage_type"
+        self.two_stage_type = two_stage_type
+
+        self._reset_parameters()
+        self.d_model = d_model
+
+    def _reset_parameters(self):
+        """Initialize the parameters of the model."""
+        for n, p in self.named_parameters():
+            if p.dim() > 1:
+                if "box_embed" not in n and "query_embed" not in n and "reference_points" not in n:
+                    nn.init.xavier_uniform_(p)
+
+
+def get_valid_ratio(mask):
+    """Compute the valid ratio of height and width from the mask."""
+    _, H, W = mask.shape
+    valid_H = torch.sum(~mask[:, :, 0], 1)
+    valid_W = torch.sum(~mask[:, 0, :], 1)
+    valid_ratio_h = valid_H.float() / H
+    valid_ratio_w = valid_W.float() / W
+    valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+    return valid_ratio
+
+
+def gen_sineembed_for_position(pos_tensor: torch.Tensor, num_feats: int = 256):
+    """Generate sinusoidal position embeddings for 2D or 4D coordinate tensors.
+
+    This function creates sinusoidal embeddings using sine and cosine functions at different frequencies, similar to
+    the positional encoding used in Transformer models. It supports both 2D position tensors (x, y) and 4D tensors
+    (x, y, w, h) for bounding box coordinates.
+
+    Args:
+        pos_tensor (torch.Tensor): Input position tensor of shape (n_query, bs, 2) for 2D coordinates or (n_query, bs,
+            4) for 4D coordinates (bounding boxes).
+        num_feats (int): Number of feature dimensions for the output embedding. Must be even. Defaults to 256.
+
+    Returns:
+        (torch.Tensor): Sinusoidal position embeddings of shape (n_query, bs, num_feats) for 2D input or (n_query, bs,
+            num_feats * 2) for 4D input.
+
+    Raises:
+        AssertionError: If num_feats is not even.
+        ValueError: If pos_tensor.size(-1) is not 2 or 4.
+
+    Examples:
+        >>> pos_2d = torch.rand(100, 8, 2)  # 100 queries, batch size 8, 2D coordinates
+        >>> embeddings_2d = gen_sineembed_for_position(pos_2d, num_feats=256)
+        >>> embeddings_2d.shape
+        torch.Size([100, 8, 256])
+        >>> pos_4d = torch.rand(50, 4, 4)  # 50 queries, batch size 4, 4D coordinates
+        >>> embeddings_4d = gen_sineembed_for_position(pos_4d, num_feats=128)
+        >>> embeddings_4d.shape
+        torch.Size([50, 4, 256])
+    """
+    assert num_feats % 2 == 0
+    num_feats = num_feats // 2
+    # n_query, bs, _ = pos_tensor.size()
+    # sineembed_tensor = torch.zeros(n_query, bs, 256)
+    scale = 2 * math.pi
+    dim_t = torch.arange(num_feats, dtype=pos_tensor.dtype, device=pos_tensor.device)
+    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
+    x_embed = pos_tensor[:, :, 0] * scale
+    y_embed = pos_tensor[:, :, 1] * scale
+    pos_x = x_embed[:, :, None] / dim_t
+    pos_y = y_embed[:, :, None] / dim_t
+    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+    if pos_tensor.size(-1) == 2:
+        pos = torch.cat((pos_y, pos_x), dim=2)
+    elif pos_tensor.size(-1) == 4:
+        w_embed = pos_tensor[:, :, 2] * scale
+        pos_w = w_embed[:, :, None] / dim_t
+        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+        h_embed = pos_tensor[:, :, 3] * scale
+        pos_h = h_embed[:, :, None] / dim_t
+        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+    else:
+        raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
+    return pos
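
The helpers above are shape-sensitive: decoder queries are layer-first, prompts are sequence-first, and the padding mask is batch-first. The sketch below is not part of the diff; it assumes the 8.3.237 wheel is installed so that `ultralytics.models.sam.sam3.model_misc` is importable, and all tensor sizes are illustrative.

```python
# Hypothetical smoke test for the new model_misc helpers (not from the diff).
import torch

from ultralytics.models.sam.sam3.model_misc import DotProductScoring, gen_sineembed_for_position

# Dot-product scoring: decoder queries hs (num_layer, bs, num_query, d_model)
# against a prompt (seq, bs, d_model) with a (bs, seq) padding mask (True = padding).
scorer = DotProductScoring(d_model=256, d_proj=128)
hs = torch.randn(3, 2, 100, 256)                   # 3 decoder layers, batch 2, 100 queries
prompt = torch.randn(7, 2, 256)                    # 7 prompt tokens
prompt_mask = torch.zeros(2, 7, dtype=torch.bool)  # all tokens valid (no padding)
scores = scorer(hs, prompt, prompt_mask)
print(scores.shape)  # torch.Size([3, 2, 100, 1]), clamped to [-12, 12]

# Sinusoidal embeddings for 2D points and 4D boxes (cx, cy, w, h).
print(gen_sineembed_for_position(torch.rand(100, 2, 2), num_feats=256).shape)  # (100, 2, 256)
print(gen_sineembed_for_position(torch.rand(50, 2, 4), num_feats=128).shape)   # (50, 2, 256)
```

For 2D inputs the embedding width equals `num_feats`, while 4D boxes produce `num_feats * 2`, matching the docstring above.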
ultralytics/models/sam/sam3/necks.py (new file)
@@ -0,0 +1,129 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""Necks are the interface between a vision backbone and the rest of the detection model."""
+
+from __future__ import annotations
+
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+
+class Sam3DualViTDetNeck(nn.Module):
+    """A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2)."""
+
+    def __init__(
+        self,
+        trunk: nn.Module,
+        position_encoding: nn.Module,
+        d_model: int,
+        scale_factors=(4.0, 2.0, 1.0, 0.5),
+        add_sam2_neck: bool = False,
+    ):
+        """
+        SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).
+
+        It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2) with different weights.
+
+        :param trunk: the backbone
+        :param position_encoding: the positional encoding to use
+        :param d_model: the dimension of the model
+        """
+        super().__init__()
+        self.trunk = trunk
+        self.position_encoding = position_encoding
+        self.convs = nn.ModuleList()
+
+        self.scale_factors = scale_factors
+        use_bias = True
+        dim: int = self.trunk.channel_list[-1]
+
+        for _, scale in enumerate(scale_factors):
+            current = nn.Sequential()
+
+            if scale == 4.0:
+                current.add_module(
+                    "dconv_2x2_0",
+                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
+                )
+                current.add_module(
+                    "gelu",
+                    nn.GELU(),
+                )
+                current.add_module(
+                    "dconv_2x2_1",
+                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
+                )
+                out_dim = dim // 4
+            elif scale == 2.0:
+                current.add_module(
+                    "dconv_2x2",
+                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
+                )
+                out_dim = dim // 2
+            elif scale == 1.0:
+                out_dim = dim
+            elif scale == 0.5:
+                current.add_module(
+                    "maxpool_2x2",
+                    nn.MaxPool2d(kernel_size=2, stride=2),
+                )
+                out_dim = dim
+            else:
+                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
+
+            current.add_module(
+                "conv_1x1",
+                nn.Conv2d(
+                    in_channels=out_dim,
+                    out_channels=d_model,
+                    kernel_size=1,
+                    bias=use_bias,
+                ),
+            )
+            current.add_module(
+                "conv_3x3",
+                nn.Conv2d(
+                    in_channels=d_model,
+                    out_channels=d_model,
+                    kernel_size=3,
+                    padding=1,
+                    bias=use_bias,
+                ),
+            )
+            self.convs.append(current)
+
+        self.sam2_convs = None
+        if add_sam2_neck:
+            # Assumes sam2 neck is just a clone of the original neck
+            self.sam2_convs = deepcopy(self.convs)
+
+    def forward(
+        self, tensor_list: list[torch.Tensor]
+    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
+        """Get the feature maps and positional encodings from the neck."""
+        xs = self.trunk(tensor_list)
+        sam3_out, sam3_pos = [], []
+        sam2_out, sam2_pos = None, None
+        if self.sam2_convs is not None:
+            sam2_out, sam2_pos = [], []
+        x = xs[-1]  # simpleFPN
+        for i in range(len(self.convs)):
+            sam3_x_out = self.convs[i](x)
+            sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
+            sam3_out.append(sam3_x_out)
+            sam3_pos.append(sam3_pos_out)
+
+            if self.sam2_convs is not None:
+                sam2_x_out = self.sam2_convs[i](x)
+                sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
+                sam2_out.append(sam2_x_out)
+                sam2_pos.append(sam2_pos_out)
+        return sam3_out, sam3_pos, sam2_out, sam2_pos
+
+    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
+        """Set the image size for the trunk backbone."""
+        self.trunk.set_imgsz(imgsz)
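
To see how the neck expands the last trunk feature map into one output per scale factor, here is a minimal sketch (not from the diff): `DummyTrunk` and `DummyPosEnc` are hypothetical stand-ins for the real ViTDet trunk and position encoding, and the import assumes the 8.3.237 wheel is installed.

```python
# Illustrative only: exercise Sam3DualViTDetNeck with stand-in modules.
import torch
import torch.nn as nn

from ultralytics.models.sam.sam3.necks import Sam3DualViTDetNeck


class DummyTrunk(nn.Module):
    channel_list = [1024]  # the neck reads channel_list[-1]

    def forward(self, x):
        # pretend ViT output at stride 16 for a 1008x1008 input: (B, 1024, 63, 63)
        return [torch.randn(x.shape[0], 1024, 63, 63)]

    def set_imgsz(self, imgsz):
        pass


class DummyPosEnc(nn.Module):
    def forward(self, x):
        return torch.zeros_like(x)  # same (B, d_model, H, W) shape as the feature map


neck = Sam3DualViTDetNeck(DummyTrunk(), DummyPosEnc(), d_model=256, add_sam2_neck=True)
feats, pos, sam2_feats, sam2_pos = neck(torch.randn(1, 3, 1008, 1008))
print([tuple(f.shape) for f in feats])
# [(1, 256, 252, 252), (1, 256, 126, 126), (1, 256, 63, 63), (1, 256, 31, 31)]
```

With the default `scale_factors=(4.0, 2.0, 1.0, 0.5)` the four outputs sit at roughly strides 4, 8, 16, and 32, and `add_sam2_neck=True` yields a second, independently weighted set of the same shapes for the SAM2 path.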
ultralytics/models/sam/sam3/sam3_image.py (new file)
@@ -0,0 +1,357 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+from __future__ import annotations
+
+from copy import deepcopy
+
+import torch
+
+from ultralytics.nn.modules.utils import inverse_sigmoid
+from ultralytics.utils.ops import xywh2xyxy
+
+from .geometry_encoders import Prompt
+from .vl_combiner import SAM3VLBackbone
+
+
+def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
+    """Helper function to update output dictionary with main and auxiliary outputs."""
+    out[out_name] = out_value[-1] if auxiliary else out_value
+    if auxiliary and update_aux:
+        if "aux_outputs" not in out:
+            out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
+        assert len(out["aux_outputs"]) == len(out_value) - 1
+        for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
+            aux_output[out_name] = aux_value
+
+
+class SAM3SemanticModel(torch.nn.Module):
+    """SAM3 model for semantic segmentation with vision-language backbone."""
+
+    def __init__(
+        self,
+        backbone: SAM3VLBackbone,
+        transformer,
+        input_geometry_encoder,
+        segmentation_head=None,
+        num_feature_levels=1,
+        o2m_mask_predict=True,
+        dot_prod_scoring=None,
+        use_instance_query: bool = True,
+        multimask_output: bool = True,
+        use_act_checkpoint_seg_head: bool = True,
+        matcher=None,
+        use_dot_prod_scoring=True,
+        supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
+        detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
+        separate_scorer_for_instance: bool = False,
+        num_interactive_steps_val: int = 0,
+    ):
+        """Initialize the SAM3SemanticModel."""
+        super().__init__()
+        self.backbone = backbone
+        self.geometry_encoder = input_geometry_encoder
+        self.transformer = transformer
+        self.hidden_dim = transformer.d_model
+        self.num_feature_levels = num_feature_levels
+        self.segmentation_head = segmentation_head
+
+        self.o2m_mask_predict = o2m_mask_predict
+
+        self.dot_prod_scoring = dot_prod_scoring
+        self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
+        self.matcher = matcher
+
+        self.num_interactive_steps_val = num_interactive_steps_val
+        self.use_dot_prod_scoring = use_dot_prod_scoring
+
+        if self.use_dot_prod_scoring:
+            assert dot_prod_scoring is not None
+            self.dot_prod_scoring = dot_prod_scoring
+            self.instance_dot_prod_scoring = None
+            if separate_scorer_for_instance:
+                self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
+        else:
+            self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
+            self.instance_class_embed = None
+            if separate_scorer_for_instance:
+                self.instance_class_embed = deepcopy(self.class_embed)
+
+        self.supervise_joint_box_scores = supervise_joint_box_scores
+        self.detach_presence_in_joint_score = detach_presence_in_joint_score
+
+        # verify the number of queries for O2O and O2M
+        num_o2o_static = self.transformer.decoder.num_queries
+        num_o2m_static = self.transformer.decoder.num_o2m_queries
+        assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
+        self.dac = self.transformer.decoder.dac
+
+        self.use_instance_query = use_instance_query
+        self.multimask_output = multimask_output
+
+        self.text_embeddings = {}
+        self.names = []
+
+    def _prepare_backbone_features(self, backbone_out, num_prompts=1):
+        """Prepare and flatten visual features from the image backbone output for further processing."""
+        if num_prompts > 1:  # expand features if there's more than one prompt
+            for i, feat in enumerate(backbone_out["backbone_fpn"]):
+                backbone_out["backbone_fpn"][i] = feat.expand(num_prompts, -1, -1, -1)
+            for i, pos in enumerate(backbone_out["vision_pos_enc"]):
+                pos = pos.expand(num_prompts, -1, -1, -1)
+                backbone_out["vision_pos_enc"][i] = pos
+        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+
+        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+        # flatten NxCxHxW to HWxNxC
+        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
+
+    def _encode_prompt(
+        self,
+        img_feats,
+        img_pos_embeds,
+        vis_feat_sizes,
+        geometric_prompt,
+        visual_prompt_embed=None,
+        visual_prompt_mask=None,
+        prev_mask_pred=None,
+    ):
+        """Encode the geometric and visual prompts."""
+        if prev_mask_pred is not None:
+            img_feats = [img_feats[-1] + prev_mask_pred]
+        # Encode geometry
+        geo_feats, geo_masks = self.geometry_encoder(
+            geo_prompt=geometric_prompt,
+            img_feats=img_feats,
+            img_sizes=vis_feat_sizes,
+            img_pos_embeds=img_pos_embeds,
+        )
+        if visual_prompt_embed is None:
+            visual_prompt_embed = torch.zeros((0, *geo_feats.shape[1:]), device=geo_feats.device)
+            visual_prompt_mask = torch.zeros(
+                (*geo_masks.shape[:-1], 0),
+                device=geo_masks.device,
+                dtype=geo_masks.dtype,
+            )
+        prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
+        prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
+        return prompt, prompt_mask
+
+    def _run_encoder(
+        self,
+        img_feats,
+        img_pos_embeds,
+        vis_feat_sizes,
+        prompt,
+        prompt_mask,
+        encoder_extra_kwargs: dict | None = None,
+    ):
+        """Run the transformer encoder."""
+        # Run the encoder
+        # make a copy of the image feature lists since the encoder may modify these lists in-place
+        memory = self.transformer.encoder(
+            src=img_feats.copy(),
+            src_key_padding_mask=None,
+            src_pos=img_pos_embeds.copy(),
+            prompt=prompt,
+            prompt_key_padding_mask=prompt_mask,
+            feat_sizes=vis_feat_sizes,
+            encoder_extra_kwargs=encoder_extra_kwargs,
+        )
+        encoder_out = {
+            # encoded image features
+            "encoder_hidden_states": memory["memory"],
+            "pos_embed": memory["pos_embed"],
+            "padding_mask": memory["padding_mask"],
+            "spatial_shapes": memory["spatial_shapes"],
+            "valid_ratios": memory["valid_ratios"],
+            "vis_feat_sizes": vis_feat_sizes,
+            # encoded text features (or other prompts)
+            "prompt_before_enc": prompt,
+            "prompt_after_enc": memory.get("memory_text", prompt),
+            "prompt_mask": prompt_mask,
+        }
+        return encoder_out
+
+    def _run_decoder(
+        self,
+        pos_embed,
+        memory,
+        src_mask,
+        out,
+        prompt,
+        prompt_mask,
+        encoder_out,
+    ):
+        """Run the transformer decoder."""
+        bs = memory.shape[1]
+        query_embed = self.transformer.decoder.query_embed.weight
+        tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
+
+        hs, reference_boxes, dec_presence_out, _ = self.transformer.decoder(
+            tgt=tgt,
+            memory=memory,
+            memory_key_padding_mask=src_mask,
+            pos=pos_embed,
+            reference_boxes=None,
+            spatial_shapes=encoder_out["spatial_shapes"],
+            valid_ratios=encoder_out["valid_ratios"],
+            tgt_mask=None,
+            memory_text=prompt,
+            text_attention_mask=prompt_mask,
+            apply_dac=False,
+        )
+        hs = hs.transpose(1, 2)  # seq-first to batch-first
+        reference_boxes = reference_boxes.transpose(1, 2)  # seq-first to batch-first
+        if dec_presence_out is not None:
+            # seq-first to batch-first
+            dec_presence_out = dec_presence_out.transpose(1, 2)
+        self._update_scores_and_boxes(
+            out,
+            hs,
+            reference_boxes,
+            prompt,
+            prompt_mask,
+            dec_presence_out=dec_presence_out,
+        )
+        return out, hs
+
+    def _update_scores_and_boxes(
+        self,
+        out,
+        hs,
+        reference_boxes,
+        prompt,
+        prompt_mask,
+        dec_presence_out=None,
+        is_instance_prompt=False,
+    ):
+        """Update output dict with class scores and box predictions."""
+        num_o2o = hs.size(2)
+        # score prediction
+        if self.use_dot_prod_scoring:
+            dot_prod_scoring_head = self.dot_prod_scoring
+            if is_instance_prompt and self.instance_dot_prod_scoring is not None:
+                dot_prod_scoring_head = self.instance_dot_prod_scoring
+            outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
+        else:
+            class_embed_head = self.class_embed
+            if is_instance_prompt and self.instance_class_embed is not None:
+                class_embed_head = self.instance_class_embed
+            outputs_class = class_embed_head(hs)
+
+        # box prediction
+        box_head = self.transformer.decoder.bbox_embed
+        if is_instance_prompt and self.transformer.decoder.instance_bbox_embed is not None:
+            box_head = self.transformer.decoder.instance_bbox_embed
+        anchor_box_offsets = box_head(hs)
+        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
+        outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
+        outputs_boxes_xyxy = xywh2xyxy(outputs_coord)
+
+        if dec_presence_out is not None:
+            _update_out(out, "presence_logit_dec", dec_presence_out, update_aux=False)
+
+        if self.supervise_joint_box_scores:
+            assert dec_presence_out is not None
+            prob_dec_presence_out = dec_presence_out.clone().sigmoid()
+            if self.detach_presence_in_joint_score:
+                prob_dec_presence_out = prob_dec_presence_out.detach()
+
+            outputs_class = inverse_sigmoid(outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)).clamp(
+                min=-10.0, max=10.0
+            )
+
+        _update_out(out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=False)
+        _update_out(out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=False)
+        _update_out(out, "pred_boxes_xyxy", outputs_boxes_xyxy[:, :, :num_o2o], update_aux=False)
+
+    def _run_segmentation_heads(
+        self,
+        out,
+        backbone_out,
+        encoder_hidden_states,
+        prompt,
+        prompt_mask,
+        hs,
+    ):
+        """Run segmentation heads and get masks."""
+        if self.segmentation_head is not None:
+            num_o2o = hs.size(2)
+            obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
+            seg_head_outputs = self.segmentation_head(
+                backbone_feats=backbone_out["backbone_fpn"],
+                obj_queries=obj_queries,
+                encoder_hidden_states=encoder_hidden_states,
+                prompt=prompt,
+                prompt_mask=prompt_mask,
+            )
+            for k, v in seg_head_outputs.items():
+                if k in self.segmentation_head.instance_keys:
+                    _update_out(out, k, v[:, :num_o2o], auxiliary=False)
+                else:
+                    out[k] = v
+        else:
+            backbone_out.pop("backbone_fpn", None)
+
+    def forward_grounding(
+        self, backbone_out: dict[str, torch.Tensor], text_ids: torch.Tensor, geometric_prompt: Prompt = None
+    ):
+        """Forward pass for grounding (detection + segmentation) given input images and text."""
+        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = self._prepare_backbone_features(
+            backbone_out, num_prompts=len(text_ids)
+        )
+        backbone_out.update({k: v for k, v in self.text_embeddings.items()})
+        with torch.profiler.record_function("SAM3Image._encode_prompt"):
+            prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
+            # index text features (note that regardless of early or late fusion, the batch size of
+            # `txt_feats` is always the number of *prompts* in the encoder)
+            txt_feats = backbone_out["language_features"][:, text_ids]
+            txt_masks = backbone_out["language_mask"][text_ids]
+            # encode text
+            prompt = torch.cat([txt_feats, prompt], dim=0)
+            prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)
+
+        # Run the encoder
+        with torch.profiler.record_function("SAM3Image._run_encoder"):
+            encoder_out = self._run_encoder(img_feats, img_pos_embeds, vis_feat_sizes, prompt, prompt_mask)
+        out = {"backbone_out": backbone_out}
+
+        # Run the decoder
+        with torch.profiler.record_function("SAM3Image._run_decoder"):
+            out, hs = self._run_decoder(
+                memory=encoder_out["encoder_hidden_states"],
+                pos_embed=encoder_out["pos_embed"],
+                src_mask=encoder_out["padding_mask"],
+                out=out,
+                prompt=prompt,
+                prompt_mask=prompt_mask,
+                encoder_out=encoder_out,
+            )
+
+        # Run segmentation heads
+        with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
+            self._run_segmentation_heads(
+                out=out,
+                backbone_out=backbone_out,
+                encoder_hidden_states=encoder_out["encoder_hidden_states"],
+                prompt=prompt,
+                prompt_mask=prompt_mask,
+                hs=hs,
+            )
+        return out
+
+    def set_classes(self, text: list[str]):
+        """Set the text embeddings for the given class names."""
+        self.text_embeddings = self.backbone.forward_text(text)
+        self.names = text
+
+    def set_imgsz(self, imgsz: tuple[int, int]):
+        """Set the image size for the model."""
+        self.backbone.set_imgsz(imgsz)
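
The `_update_out` helper is what splits per-decoder-layer predictions into the final output plus `aux_outputs` for auxiliary supervision. Below is a small, self-contained illustration, not part of the diff, assuming the installed 8.3.237 wheel exposes `ultralytics.models.sam.sam3.sam3_image`; tensor sizes are made up.

```python
# Illustrative sketch of the _update_out helper's main/auxiliary split.
import torch

from ultralytics.models.sam.sam3.sam3_image import _update_out

out = {}
logits = torch.randn(3, 2, 100, 1)  # (num_layers, bs, num_queries, 1)
boxes = torch.rand(3, 2, 100, 4)    # (num_layers, bs, num_queries, 4) in cxcywh

_update_out(out, "pred_logits", logits)  # defaults: auxiliary=True, update_aux=True
_update_out(out, "pred_boxes", boxes)

print(out["pred_logits"].shape)      # torch.Size([2, 100, 1]) -> last decoder layer only
print(len(out["aux_outputs"]))       # 2 -> one dict per intermediate decoder layer
print(out["aux_outputs"][0].keys())  # dict_keys(['pred_logits', 'pred_boxes'])
```

Note that `_update_scores_and_boxes` above calls the helper with `update_aux=False`, so only the top-level keys are written there; the defaults shown here additionally populate the per-layer auxiliary dicts used for auxiliary losses.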