frontveg 0.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. frontveg/__init__.py +11 -0
  2. frontveg/_tests/__init__.py +0 -0
  3. frontveg/_tests/test_widget.py +66 -0
  4. frontveg/_version.py +21 -0
  5. frontveg/_widget.py +132 -0
  6. frontveg/napari.yaml +14 -0
  7. frontveg/utils.py +95 -0
  8. frontveg-0.1.dev1.dist-info/METADATA +143 -0
  9. frontveg-0.1.dev1.dist-info/RECORD +44 -0
  10. frontveg-0.1.dev1.dist-info/WHEEL +5 -0
  11. frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
  12. frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
  13. frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
  14. sam2/__init__.py +11 -0
  15. sam2/automatic_mask_generator.py +454 -0
  16. sam2/build_sam.py +167 -0
  17. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  18. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  19. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  20. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  21. sam2/modeling/__init__.py +5 -0
  22. sam2/modeling/backbones/__init__.py +5 -0
  23. sam2/modeling/backbones/hieradet.py +317 -0
  24. sam2/modeling/backbones/image_encoder.py +134 -0
  25. sam2/modeling/backbones/utils.py +95 -0
  26. sam2/modeling/memory_attention.py +169 -0
  27. sam2/modeling/memory_encoder.py +181 -0
  28. sam2/modeling/position_encoding.py +221 -0
  29. sam2/modeling/sam/__init__.py +5 -0
  30. sam2/modeling/sam/mask_decoder.py +295 -0
  31. sam2/modeling/sam/prompt_encoder.py +182 -0
  32. sam2/modeling/sam/transformer.py +360 -0
  33. sam2/modeling/sam2_base.py +907 -0
  34. sam2/modeling/sam2_utils.py +323 -0
  35. sam2/sam2_hiera_b+.yaml +1 -0
  36. sam2/sam2_hiera_l.yaml +1 -0
  37. sam2/sam2_hiera_s.yaml +1 -0
  38. sam2/sam2_hiera_t.yaml +1 -0
  39. sam2/sam2_image_predictor.py +466 -0
  40. sam2/sam2_video_predictor.py +1172 -0
  41. sam2/utils/__init__.py +5 -0
  42. sam2/utils/amg.py +348 -0
  43. sam2/utils/misc.py +349 -0
  44. sam2/utils/transforms.py +118 -0
@@ -0,0 +1,295 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.sam2_utils import LayerNorm2d, MLP
13
+
14
+
15
class MaskDecoder(nn.Module):
    """Predict masks, mask-quality scores and object scores from embeddings.

    Learned output tokens (an optional object-score token, one IoU token and
    ``num_multimask_outputs + 1`` mask tokens) are concatenated with the sparse
    prompt embeddings and passed through ``transformer`` together with the
    image embedding. Each mask token is then mapped by its own hypernetwork
    MLP to per-channel weights, which are multiplied with the upscaled image
    embedding to produce one mask per token.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        use_high_res_features: bool = False,
        iou_prediction_use_sigmoid: bool = False,
        dynamic_multimask_via_stability: bool = False,
        dynamic_multimask_stability_delta: float = 0.05,
        dynamic_multimask_stability_thresh: float = 0.98,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        use_multimask_token_for_obj_ptr: bool = False,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
          use_high_res_features (bool): if True, create the 1x1 convolutions
            `conv_s0` / `conv_s1` and make `predict_masks` add two
            caller-supplied high-resolution feature maps into the upscaling
            path. The convolutions are not applied inside this class.
          iou_prediction_use_sigmoid (bool): forwarded to the IoU head as
            `sigmoid_output`
          dynamic_multimask_via_stability (bool): in eval mode, when a single
            mask is requested, fall back to the best multimask output if the
            single-mask output has a low stability score
          dynamic_multimask_stability_delta (float): logit offset used for the
            upper/lower thresholds of the stability score
          dynamic_multimask_stability_thresh (float): minimum stability score
            for the single-mask output to be kept
          pred_obj_scores (bool): if True, add an object-score token and a
            head that predicts object-score logits from it
          pred_obj_scores_mlp (bool): if True (and `pred_obj_scores` is set),
            use a 3-layer MLP instead of a linear layer as the object-score
            head
          use_multimask_token_for_obj_ptr (bool): if True, `forward` returns
            the multimask tokens (instead of the single-mask token) as SAM
            output tokens when `multimask_output` is requested
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        # One extra token: index 0 is the single-mask output token.
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.pred_obj_scores = pred_obj_scores
        if self.pred_obj_scores:
            self.obj_score_token = nn.Embedding(1, transformer_dim)
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

        # Two stride-2 transposed convolutions: 4x spatial upscaling while
        # reducing channels to transformer_dim // 8.
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.use_high_res_features = use_high_res_features
        if use_high_res_features:
            self.conv_s0 = nn.Conv2d(
                transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
            )
            self.conv_s1 = nn.Conv2d(
                transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
            )

        # One hypernetwork MLP per mask token.
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for _ in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            sigmoid_output=iou_prediction_use_sigmoid,
        )
        if self.pred_obj_scores:
            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
            if pred_obj_scores_mlp:
                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

        # When outputting a single mask, optionally we can dynamically fall back to the best
        # multimask output token if the single mask output token gives low stability scores.
        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.
          repeat_image (bool): if True, the image embeddings are repeated along
            the batch dimension once per prompt; otherwise their batch size
            must already equal that of the prompts.
          high_res_features (list(torch.Tensor) or none): the two feature maps
            `[feat_s0, feat_s1]` added into the upscaling path; required when
            the decoder was built with `use_high_res_features=True`.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
          torch.Tensor: batched SAM token for mask output
          torch.Tensor: batched object score logits
        """
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            masks = masks[:, 1:, :, :]
            iou_pred = iou_pred[:, 1:]
        elif self.dynamic_multimask_via_stability and not self.training:
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]

        if multimask_output and self.use_multimask_token_for_obj_ptr:
            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
        else:
            # Take the mask output token. Here we *always* use the token for single mask output.
            # At test time, even if we track after 1-click (and using multimask_output=True),
            # we still take the single mask token here. The rationale is that we always track
            # after multiple clicks during training, so the past tokens seen during training
            # are always the single mask token (and we'll let it be the object-memory token).
            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape

        # Prepare output
        return masks, iou_pred, sam_tokens_out, object_score_logits

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details.

        Returns the masks and IoU predictions for *all* mask tokens, the
        transformer outputs of the mask tokens, and the object score logits.

        Raises:
          ValueError: if the decoder was built with `use_high_res_features=True`
            but `high_res_features` is None.
        """
        # Fail early with a clear message instead of an opaque unpacking error
        # after the transformer has already run.
        if self.use_high_res_features and high_res_features is None:
            raise ValueError(
                "high_res_features must be provided when the decoder is built "
                "with use_high_res_features=True"
            )

        # Concatenate output tokens. `s` is the index of the IoU token: it is
        # shifted by one when the object-score token is prepended.
        if self.pred_obj_scores:
            learned_tokens = [
                self.obj_score_token.weight,
                self.iou_token.weight,
                self.mask_tokens.weight,
            ]
            s = 1
        else:
            learned_tokens = [self.iou_token.weight, self.mask_tokens.weight]
            s = 0
        output_tokens = torch.cat(learned_tokens, dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if repeat_image:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            assert image_embeddings.shape[0] == tokens.shape[0]
            src = image_embeddings
        src = src + dense_prompt_embeddings
        assert (
            image_pe.size(0) == 1
        ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, s, :]
        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        if not self.use_high_res_features:
            upscaled_embedding = self.output_upscaling(src)
        else:
            # Run the upscaling layers one by one so the high-resolution
            # features can be added after each transposed convolution.
            dc1, ln1, act1, dc2, act2 = self.output_upscaling
            feat_s0, feat_s1 = high_res_features
            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        # Each mask token goes through its own hypernetwork MLP.
        hyper_in = torch.stack(
            [
                mlp(mask_tokens_out[:, i, :])
                for i, mlp in enumerate(self.output_hypernetworks_mlps)
            ],
            dim=1,
        )
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        if self.pred_obj_scores:
            # The object-score token is always the first output token.
            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
        else:
            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

        return masks, iou_pred, mask_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds.

        The last two dimensions of `mask_logits` are flattened and reduced, so
        the result has one score per leading index. A mask with no logit above
        the lower threshold gets a score of 1.0.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        # Guard against division by zero when the lower-threshold mask is empty.
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.

        Returns the selected mask logits and IoU scores, each keeping a
        singleton mask dimension at index 1.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
        batch_inds = torch.arange(
            multimask_iou_scores.size(0), device=all_iou_scores.device
        )
        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
        best_multimask_logits = best_multimask_logits.unsqueeze(1)
        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
@@ -0,0 +1,182 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+
14
+ from sam2.modeling.sam2_utils import LayerNorm2d
15
+
16
+
17
class PromptEncoder(nn.Module):
    """Turn point, box and mask prompts into sparse and dense embeddings."""

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image
            as input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # Label embeddings: pos/neg point + 2 box corners
        self.num_point_embeddings: int = 4
        self.point_embeddings = nn.ModuleList(
            [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        )
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # Mask prompts are expected at 4x the image-embedding resolution; the
        # two stride-2 convolutions below bring them back down by 4x.
        embed_h = image_embedding_size[0]
        embed_w = image_embedding_size[1]
        self.mask_input_size = (4 * embed_h, 4 * embed_w)

        quarter_chans = mask_in_chans // 4
        downscaling_layers = [
            nn.Conv2d(1, quarter_chans, kernel_size=2, stride=2),
            LayerNorm2d(quarter_chans),
            activation(),
            nn.Conv2d(quarter_chans, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        ]
        self.mask_downscaling = nn.Sequential(*downscaling_layers)
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        dense_pe = self.pe_layer(self.image_embedding_size)
        return dense_pe.unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts.

        When `pad` is True, one extra point at the origin with label -1 is
        appended to every batch entry. Points labelled -1 get only the
        "not a point" embedding; points labelled 0..3 get the matching label
        embedding added to their positional encoding.
        """
        coords = points + 0.5  # Shift to center of pixel
        if pad:
            extra_point = torch.zeros((coords.shape[0], 1, 2), device=coords.device)
            extra_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            coords = torch.cat([coords, extra_point], dim=1)
            labels = torch.cat([labels, extra_label], dim=1)

        embedded = self.pe_layer.forward_with_coords(coords, self.input_image_size)

        # Padding points carry no position information, only the learned
        # "not a point" vector.
        is_padding = labels == -1
        embedded[is_padding] = 0.0
        embedded[is_padding] += self.not_a_point_embed.weight

        for label_value, label_embed in enumerate(self.point_embeddings):
            embedded[labels == label_value] += label_embed.weight
        return embedded

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts.

        Each box is viewed as two corner points; the first corner gets label
        embedding 2 and the second label embedding 3.
        """
        # Shift to center of pixel, then view every box as two (x, y) corners.
        corners = (boxes + 0.5).reshape(-1, 2, 2)
        embedded = self.pe_layer.forward_with_coords(corners, self.input_image_size)
        embedded[:, 0, :] += self.point_embeddings[2].weight
        embedded[:, 1, :] += self.point_embeddings[3].weight
        return embedded

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        return self.mask_downscaling(masks)

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.

        The first prompt that is present decides, in the order points, boxes,
        masks; with no prompt at all the batch size is 1.
        """
        if points is not None:
            return points[0].shape[0]
        if boxes is not None:
            return boxes.shape[0]
        if masks is not None:
            return masks.shape[0]
        return 1

    def _get_device(self) -> torch.device:
        """Return the device holding the first point-embedding weight."""
        first_embedding = self.point_embeddings[0]
        return first_embedding.weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)

        # Start from an empty (bs, 0, embed_dim) tensor so the result is well
        # formed even when neither points nor boxes are given.
        sparse_parts = [
            torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        ]
        if points is not None:
            coords, labels = points
            # Points are padded only when no box accompanies them.
            sparse_parts.append(self._embed_points(coords, labels, pad=(boxes is None)))
        if boxes is not None:
            sparse_parts.append(self._embed_boxes(boxes))
        sparse_embeddings = torch.cat(sparse_parts, dim=1)

        if masks is None:
            # No mask prompt: broadcast the learned "no mask" vector over the
            # whole embedding grid.
            no_mask = self.no_mask_embed.weight.reshape(1, -1, 1, 1)
            dense_embeddings = no_mask.expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )
        else:
            dense_embeddings = self._embed_masks(masks)

        return sparse_embeddings, dense_embeddings