nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,179 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
+
16
+
17
class MaskDownSampler(nn.Module):
    """
    Shrink a mask by ``total_stride`` using a stack of strided convolutions.

    Each stage is ``Conv2d -> LayerNorm2d -> activation`` and reduces the
    spatial size by ``stride`` while multiplying the channel count by
    ``stride**2``. A final 1x1 convolution maps the result to ``embed_dim``
    channels. (Upstream note: LayerNorm is applied per *token*, like in ViT.)

    Args:
        embed_dim: number of output channels of the final 1x1 projection.
        kernel_size: kernel size of every strided convolution.
        stride: stride of every strided convolution.
        padding: padding of every strided convolution.
        total_stride: overall reduction; must equal ``stride**num_layers``.
        activation: zero-argument factory producing the activation module.
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
    ):
        super().__init__()
        num_layers = int(math.log2(total_stride) // math.log2(stride))
        assert stride**num_layers == total_stride

        channel_growth = stride**2
        stages = []
        # The mask enters with a single channel.
        in_chans = out_chans = 1
        for _ in range(num_layers):
            out_chans = in_chans * channel_growth
            stages.append(
                nn.Conv2d(
                    in_chans,
                    out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            stages.append(LayerNorm2d(out_chans))
            stages.append(activation())
            in_chans = out_chans

        # Linear (1x1) projection to the requested embedding width.
        stages.append(nn.Conv2d(out_chans, embed_dim, kernel_size=1))
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the downsampling stack and return the result."""
        return self.encoder(x)
59
+
60
+
61
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62
class CXBlock(nn.Module):
    r"""ConvNeXt-style residual block.

    Pipeline used here: conv (depthwise when ``use_dwconv``) -> LayerNorm2d
    -> permute to (N, H, W, C) -> Linear -> GELU -> Linear -> optional
    per-channel scale -> permute back -> drop-path -> add the block input.

    Args:
        dim (int): Number of input (and output) channels.
        kernel_size (int): Kernel size of the first convolution. Default: 7
        padding (int): Padding of the first convolution. Default: 3
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a value
            that is not positive disables the scale. Default: 1e-6.
        use_dwconv (bool): Use ``groups=dim`` (depthwise) when True, a
            regular convolution otherwise. Default: True
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        super().__init__()
        conv_groups = dim if use_dwconv else 1
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=conv_groups,
        )  # depthwise conv
        self.norm = LayerNorm2d(dim, eps=1e-6)
        # Pointwise (1x1) convolutions expressed as linear layers on channels-last data.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` of layout (N, C, H, W) and return the same layout."""
        shortcut = x
        out = self.norm(self.dwconv(x))
        out = out.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.pwconv2(self.act(self.pwconv1(out)))
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return shortcut + self.drop_path(out)
116
+
117
+
118
class Fuser(nn.Module):
    """
    Apply an optional 1x1 input projection followed by ``num_layers`` clones of ``layer``.

    Args:
        layer: module that is cloned via ``get_clones``.
        num_layers: number of clones to chain.
        dim: channel count of the 1x1 projection; required when
            ``input_projection`` is True.
        input_projection: when True, a ``Conv2d(dim, dim, 1)`` is applied
            before the cloned layers; otherwise the input passes unchanged.
    """

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        super().__init__()
        # Default to a no-op projection; replaced below if one is requested.
        self.proj = nn.Identity()
        self.layers = get_clones(layer, num_layers)

        if not input_projection:
            return
        assert dim is not None
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Project ``x`` (normally (N, C, H, W)) and pass it through every layer in order."""
        out = self.proj(x)
        for block in self.layers:
            out = block(out)
        return out
134
+
135
+
136
class MemoryEncoder(nn.Module):
    """
    Fuse pixel features with a downsampled mask into a memory feature map.

    Args:
        out_dim: channel count of the returned features. When it differs
            from ``in_dim`` a 1x1 convolution maps ``in_dim -> out_dim``;
            otherwise the output projection is the identity.
        mask_downsampler: module applied to the (optionally sigmoided) masks.
            Its output is added to the projected pixel features, so the two
            must be broadcast-compatible.
        fuser: module applied to the sum of pixel features and mask features.
        position_encoding: module called on the output features to produce
            the positional encoding.
        in_dim: channel count of ``pix_feat``.
    """

    def __init__(
        self,
        out_dim,
        mask_downsampler,
        fuser,
        position_encoding,
        in_dim=256,  # in_dim of pix_feats
    ):
        super().__init__()

        self.mask_downsampler = mask_downsampler

        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = fuser
        self.position_encoding = position_encoding
        self.out_proj = nn.Identity()
        if out_dim != in_dim:
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(
        self,
        pix_feat: torch.Tensor,
        masks: torch.Tensor,
        skip_mask_sigmoid: bool = False,
    ) -> dict:
        """
        Encode ``pix_feat`` and ``masks`` into memory features.

        Args:
            pix_feat: pixel features; moved to the device of the processed
                masks before use.
            masks: mask tensor; passed through a sigmoid unless
                ``skip_mask_sigmoid`` is True, then through
                ``mask_downsampler``.
            skip_mask_sigmoid: set to True when ``masks`` should be used
                as-is (no sigmoid).

        Returns:
            A dict with two entries:
            ``"vision_features"`` - the fused, output-projected features;
            ``"vision_pos_enc"`` - a one-element list holding the positional
            encoding of those features, cast to the features' dtype.
        """
        ## Process masks
        # sigmoid, so that less domain shift from gt masks which are bool
        if not skip_mask_sigmoid:
            masks = F.sigmoid(masks)
        masks = self.mask_downsampler(masks)

        ## Fuse pix_feats and downsampled masks
        # in case the visual features are on CPU, move them to the masks' device
        pix_feat = pix_feat.to(masks.device)

        x = self.pix_feat_proj(pix_feat)
        x = x + masks
        x = self.fuser(x)
        x = self.out_proj(x)

        pos = self.position_encoding(x).to(x.dtype)

        return {"vision_features": x, "vision_pos_enc": [pos]}
@@ -0,0 +1,217 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.

    Args:
        num_pos_feats: total embedding width (must be even); half is used
            for the y axis and half for the x axis.
        temperature: base of the geometric frequency progression.
        normalize: when True, grid positions are divided by the grid extent
            and multiplied by ``scale`` before encoding.
        scale: multiplier for normalized positions; defaults to ``2 * pi``.
            Passing a value together with ``normalize=False`` raises
            ``ValueError``.
        warmup_cache: when True and CUDA is available, encodings for the
            grids ``image_size // stride`` are precomputed at construction.
        image_size: image side length used for the warm-up grids.
        strides: strides used for the warm-up grids.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
        # The following settings are only relevant
        # for warming up the cache for compilation
        warmup_cache: bool = True,
        image_size: int = 1024,
        strides: Tuple[int, ...] = (4, 8, 16, 32),
    ):
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        # Per-axis width: y and x each get half of the channels.
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        # Maps (H, W) -> encoding of shape (C, H, W) for a single sample.
        self.cache = {}
        if warmup_cache and torch.cuda.is_available():
            # Warmup cache for cuda, to help with compilation
            device = torch.device("cuda")
            for stride in strides:
                cache_key = (image_size // stride, image_size // stride)
                self._pe(1, device, *cache_key)

    def _encode_xy(self, x, y):
        """Sine/cosine-encode two equal-length 1-D tensors of normalized positions."""
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        # Interleave sin (even channels) and cos (odd channels).
        pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
        pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        """Return ``[enc(y), enc(x), h, w]`` concatenated along dim 1 for 1-D inputs."""
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        """Return ``[enc(y), enc(x), label]`` along the last dim for 2-D (batch, n) inputs."""
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    @torch.no_grad()
    def _pe(self, B, device, *cache_key):
        """
        Return the (B, C, H, W) grid encoding for ``cache_key == (H, W)``.

        A cached single-sample encoding is moved to ``device`` and repeated
        ``B`` times; otherwise the encoding is computed and its first sample
        stored in ``self.cache``.
        """
        H, W = cache_key
        if cache_key in self.cache:
            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)

        # 1-based row / column indices broadcast over the whole grid.
        y_embed = torch.arange(1, H + 1, dtype=torch.float32, device=device).view(1, -1, 1).repeat(B, 1, W)
        x_embed = torch.arange(1, W + 1, dtype=torch.float32, device=device).view(1, 1, -1).repeat(B, H, 1)

        if self.normalize:
            eps = 1e-6
            # Divide by the last (largest) index so positions lie in (0, scale).
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        self.cache[cache_key] = pos[0]
        return pos

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        """Return the grid encoding matching ``x``'s batch size, last two dims and device."""
        B = x.shape[0]
        cache_key = (x.shape[-2], x.shape[-1])
        return self._pe(B, x.device, *cache_key)
115
+
116
+
117
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.

    Args:
        num_pos_feats: number of random frequencies; the encoding has
            ``2 * num_pos_feats`` channels (sine and cosine parts).
        scale: standard-deviation multiplier of the random frequency matrix.
            ``None`` or a non-positive value falls back to ``1.0``.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # Registered as a buffer so it follows the module across devices
        # and is stored in the state dict without being trained.
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # cumsum - 0.5 places every sample at the centre of its cell.
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """
        Positionally encode points that are not normalized to [0,1].

        Args:
            coords_input: B x N x 2 tensor of (x, y) coordinates; it is
                cloned, never modified.
            image_size: ``(height, width)`` used to normalize y and x.

        Returns:
            B x N x C encoding of the normalized coordinates.
        """
        coords = coords_input.clone()
        if not coords.is_floating_point():
            # The normalization below assigns in place; on an integer tensor
            # the fractional results would be truncated, so promote first.
            coords = coords.to(torch.float)
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
159
+
160
+
161
+ # Rotary Positional Encoding, adapted from:
162
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
163
+ # 2. https://github.com/naver-ai/rope-vit
164
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
165
+
166
+
167
def init_t_xy(end_x: int, end_y: int):
    """
    Return the x and y index of every cell of a row-major ``end_y`` x ``end_x`` grid.

    Both results are float32 1-D tensors of length ``end_x * end_y``: the
    first holds the position within a row (flat index modulo ``end_x``), the
    second the row number (flat index floor-divided by ``end_x``).
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    col_index = (flat_index % end_x).float()
    row_index = torch.div(flat_index, end_x, rounding_mode="floor").float()
    return col_index, row_index
172
+
173
+
174
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    """
    Build complex axial rotary factors for an ``end_y`` x ``end_x`` grid.

    ``dim // 4`` frequencies ``1 / theta**(k / dim)`` (``k = 0, 4, 8, ...``)
    are multiplied by each position's x index and, separately, its y index.
    The resulting angles become unit-magnitude complex numbers, and the x
    and y halves are concatenated along the last dimension.
    """
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    # The same frequency ladder is used for both axes.
    base_freqs = 1.0 / (theta**exponents)

    t_x, t_y = init_t_xy(end_x, end_y)
    angles_x = torch.outer(t_x, base_freqs)
    angles_y = torch.outer(t_y, base_freqs)
    cis_x = torch.polar(torch.ones_like(angles_x), angles_x)
    cis_y = torch.polar(torch.ones_like(angles_y), angles_y)
    return torch.cat([cis_x, cis_y], dim=-1)
184
+
185
+
186
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have exactly the shape of ``x``'s last two
    dimensions; the result keeps those two and has size 1 in every leading
    dimension of ``x``.
    """
    rank = x.ndim
    assert rank > 1
    trailing = (x.shape[-2], x.shape[-1])
    assert freqs_cis.shape == trailing
    leading_ones = [1] * (rank - 2)
    return freqs_cis.view(*leading_ones, *trailing)
192
+
193
+
194
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    """
    Rotate queries and keys by the complex factors in ``freqs_cis``.

    The last dimension of ``xq`` / ``xk`` is read as consecutive
    (real, imag) pairs, multiplied by ``freqs_cis`` and flattened back.
    Results are returned in the dtype and on the device of the respective
    input. When ``xk`` has length 0 along dim -2 it is returned untouched.
    With ``repeat_freqs_k`` the factors are tiled along the sequence
    dimension ``k_len // q_len`` times before being applied to the keys.
    """

    def _to_complex(t):
        # (..., 2 * m) real -> (..., m) complex, computed in float32.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _to_complex(xq)
    k_complex = _to_complex(xk) if xk.shape[-2] != 0 else None
    freqs_cis = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * freqs_cis).flatten(3)
    q_rotated = q_rotated.type_as(xq).to(xq.device)
    if k_complex is None:
        # no keys to rotate, due to dropout
        return q_rotated, xk

    # repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k:
        repeats = k_complex.shape[-2] // q_complex.shape[-2]
        if freqs_cis.is_cuda:
            unit_dims = [1] * (freqs_cis.ndim - 2)
            freqs_cis = freqs_cis.repeat(*unit_dims, repeats, 1)
        else:
            # torch.repeat on complex numbers may not be supported on non-CUDA devices
            # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
            freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, repeats, -1, -1).flatten(2, 3)

    k_rotated = torch.view_as_real(k_complex * freqs_cis).flatten(3)
    return q_rotated, k_rotated.type_as(xk).to(xk.device)
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
@@ -0,0 +1,274 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.sam2_utils import LayerNorm2d, MLP
13
+
14
+
15
class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        use_high_res_features: bool = False,
        iou_prediction_use_sigmoid=False,
        dynamic_multimask_via_stability=False,
        dynamic_multimask_stability_delta=0.05,
        dynamic_multimask_stability_thresh=0.98,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        use_multimask_token_for_obj_ptr: bool = False,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
          use_high_res_features (bool): create the 1x1 convs ``conv_s0`` /
            ``conv_s1`` and add the ``high_res_features`` passed to
            ``forward`` into the upscaling path
          iou_prediction_use_sigmoid (bool): forwarded to the IoU head as
            ``sigmoid_output``
          dynamic_multimask_via_stability (bool): in eval mode with
            single-mask output, fall back to the best multimask output
            when the single mask's stability score is below the threshold
          dynamic_multimask_stability_delta (float): logit offset used for
            the upper/lower thresholds of the stability score
          dynamic_multimask_stability_thresh (float): minimum stability
            score for keeping the single-mask output
          pred_obj_scores (bool): add an object-score token and head
          pred_obj_scores_mlp (bool): use a 3-layer MLP instead of a linear
            layer for the object-score head
          use_multimask_token_for_obj_ptr (bool): with multimask output,
            return the multimask tokens instead of the single-mask token
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        # One extra token for the single-mask output (index 0).
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.pred_obj_scores = pred_obj_scores
        if self.pred_obj_scores:
            self.obj_score_token = nn.Embedding(1, transformer_dim)
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.use_high_res_features = use_high_res_features
        if use_high_res_features:
            self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
            self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)

        # One hypernetwork MLP per mask token.
        self.output_hypernetworks_mlps = nn.ModuleList(
            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
        )

        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            sigmoid_output=iou_prediction_use_sigmoid,
        )
        if self.pred_obj_scores:
            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
            if pred_obj_scores_mlp:
                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

        # When outputting a single mask, optionally we can dynamically fall back to the best
        # multimask output token if the single mask output token gives low stability scores.
        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.
          repeat_image (bool): when True, ``image_embeddings`` is repeated
            along the batch dimension once per prompt row; when False its
            batch size must already match the prompts.
          high_res_features (list of torch.Tensor, optional): the pair
            ``(feat_s0, feat_s1)`` added into the upscaling path; required
            when the decoder was built with ``use_high_res_features=True``.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
          torch.Tensor: batched SAM token for mask output
          torch.Tensor: batched object score logits
        """
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            masks = masks[:, 1:, :, :]
            iou_pred = iou_pred[:, 1:]
        elif self.dynamic_multimask_via_stability and not self.training:
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]

        if multimask_output and self.use_multimask_token_for_obj_ptr:
            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
        else:
            # Take the mask output token. Here we *always* use the token for single mask output.
            # At test time, even if we track after 1-click (and using multimask_output=True),
            # we still take the single mask token here. The rationale is that we always track
            # after multiple clicks during training, so the past tokens seen during training
            # are always the single mask token (and we'll let it be the object-memory token).
            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape

        # Prepare output
        return masks, iou_pred, sam_tokens_out, object_score_logits

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predicts masks. See 'forward' for more details.

        Returns the masks for *all* mask tokens, the IoU predictions, the
        mask output tokens and the object score logits (a constant 10.0
        when ``pred_obj_scores`` is disabled).
        """
        # Concatenate output tokens
        s = 0  # index of the IoU token in the transformer output
        if self.pred_obj_scores:
            output_tokens = torch.cat(
                [
                    self.obj_score_token.weight,
                    self.iou_token.weight,
                    self.mask_tokens.weight,
                ],
                dim=0,
            )
            # The object-score token sits in front, shifting everything by one.
            s = 1
        else:
            output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if repeat_image:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            assert image_embeddings.shape[0] == tokens.shape[0]
            src = image_embeddings
        src = src + dense_prompt_embeddings
        assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, s, :]
        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        if not self.use_high_res_features:
            upscaled_embedding = self.output_upscaling(src)
        else:
            # Unpack the upscaling stack so the high-res features can be
            # added after each transposed convolution.
            dc1, ln1, act1, dc2, act2 = self.output_upscaling
            feat_s0, feat_s1 = high_res_features
            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        hyper_in_list: List[torch.Tensor] = [
            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
        ]
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        if self.pred_obj_scores:
            assert s == 1
            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
        else:
            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

        return masks, iou_pred, mask_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds.

        The score is ``count(logit > delta) / count(logit > -delta)`` over the
        last two dimensions, and 1.0 where the denominator is zero.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.

        Returns the selected mask logits (one mask per batch entry) and the
        matching IoU scores.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
        batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
        best_multimask_logits = best_multimask_logits.unsqueeze(1)
        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out