nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,194 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+
14
+ from sam2.modeling.sam2_utils import LayerNorm2d
15
+
16
+
17
class PromptEncoder(nn.Module):
    """Turns point, box and mask prompts into sparse and dense embeddings."""

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as
            input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation class instantiated between
            the mask-downscaling convolutions.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # pos/neg point + 2 box corners
        self.num_point_embeddings: int = 4
        self.point_embeddings = nn.ModuleList(
            [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        )
        # Learned vector that replaces the embedding of any point labelled -1.
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # The two stride-2 convolutions below shrink a mask by 4x per side, so
        # the expected mask input is 4x the image embedding size.
        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        quarter_chans = mask_in_chans // 4
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, quarter_chans, kernel_size=2, stride=2),
            LayerNorm2d(quarter_chans),
            activation(),
            nn.Conv2d(quarter_chans, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Dense embedding used by forward() when no mask prompt is given.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        dense_pe = self.pe_layer(self.image_embedding_size)
        return dense_pe.unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts.

        Coordinates are shifted by 0.5 and positionally encoded. When ``pad``
        is true, one extra zero point with label -1 is appended per batch
        entry. Points labelled -1 are replaced by ``not_a_point_embed``;
        points labelled 0..3 get the matching ``point_embeddings`` entry added.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            extra_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            extra_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, extra_point], dim=1)
            labels = torch.cat([labels, extra_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)

        # Label -1 ("not a point"): discard the positional part entirely.
        point_embedding = torch.where(
            (labels == -1).unsqueeze(-1),
            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
            point_embedding,
        )
        # Labels 0..3: add the learned embedding of the same index
        # (point_embeddings holds exactly four entries, see __init__).
        for label_value, label_embed in enumerate(self.point_embeddings):
            point_embedding = torch.where(
                (labels == label_value).unsqueeze(-1),
                point_embedding + label_embed.weight,
                point_embedding,
            )
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts.

        Each box is viewed as two (x, y)-style corner coordinates; the first
        corner gets ``point_embeddings[2]`` added and the second
        ``point_embeddings[3]``.
        """
        corners = (boxes + 0.5).reshape(-1, 2, 2)  # Shift to center of pixel
        corner_embedding = self.pe_layer.forward_with_coords(corners, self.input_image_size)
        for corner_idx, embed_idx in ((0, 2), (1, 3)):
            corner_embedding[:, corner_idx, :] += self.point_embeddings[embed_idx].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs by running them through ``mask_downscaling``."""
        return self.mask_downscaling(masks)

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.

        The first prompt that is present wins, checked in the order points,
        boxes, masks; with no prompt at all the batch size is 1.
        """
        if points is not None:
            return points[0].shape[0]
        if boxes is not None:
            return boxes.shape[0]
        if masks is not None:
            return masks.shape[0]
        return 1

    def _get_device(self) -> torch.device:
        """Returns the device holding the first point embedding's weight."""
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)

        # Start from an empty (B, 0, C) tensor so the result is well-formed
        # even when neither points nor boxes are supplied.
        sparse_parts = [torch.empty((bs, 0, self.embed_dim), device=self._get_device())]
        if points is not None:
            coords, labels = points
            # Points are padded with a "not a point" entry only when no box is given.
            sparse_parts.append(self._embed_points(coords, labels, pad=(boxes is None)))
        if boxes is not None:
            sparse_parts.append(self._embed_boxes(boxes))
        sparse_embeddings = torch.cat(sparse_parts, dim=1)

        if masks is None:
            # Broadcast the learned "no mask" vector over the embedding grid.
            embed_h, embed_w = self.image_embedding_size[0], self.image_embedding_size[1]
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, embed_h, embed_w
            )
        else:
            dense_embeddings = self._embed_masks(masks)

        return sparse_embeddings, dense_embeddings
@@ -0,0 +1,293 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from functools import partial
9
+ from typing import Tuple, Type
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn, Tensor
14
+
15
+ from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
16
+ from sam2.modeling.sam2_utils import MLP
17
+
18
+
19
class TwoWayTransformer(nn.Module):
    """Stack of two-way attention blocks followed by a final token-to-image attention."""

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
          attention_downsample_rate (int): downsample rate handed to every
            cross-attention layer and to the final attention layer
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        # Only the very first block is told to skip the positional encoding
        # in its self-attention step.
        blocks = [
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                attention_downsample_rate=attention_downsample_rate,
                skip_first_layer_pe=(block_idx == 0),
            )
            for block_idx in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

        self.final_attn_token_to_image = Attention(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # The unpacked values are unused, but the unpack itself rejects any
        # image embedding that is not 4-dimensional.
        _batch, _channels, _height, _width = image_embedding.shape

        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        image_tokens = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # The point embedding is both the initial query and the query PE.
        queries, keys = point_embedding, image_tokens

        # Apply transformer blocks
        for block in self.layers:
            queries, keys = block(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image,
        # with a residual connection and the final layernorm.
        attended = self.final_attn_token_to_image(
            q=queries + point_embedding,
            k=keys + image_pe,
            v=keys,
        )
        queries = self.norm_final_attn(queries + attended)

        return queries, keys
108
+
109
+
110
class TwoWayAttentionBlock(nn.Module):
    """One decoder block that updates both the sparse tokens and the image tokens."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          attention_downsample_rate (int): downsample rate used by the two
            cross-attention layers (the self-attention layer uses the default)
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()

        # (1) token self-attention
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # (2) tokens -> image cross-attention
        self.cross_attn_token_to_image = Attention(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # (3) MLP on the tokens
        self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        # (4) image -> tokens cross-attention
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Runs the four sub-layers and returns the updated ``(queries, keys)``.

        ``query_pe`` / ``key_pe`` are added to the attention inputs used as
        q and k, but never to the tensors used as values.
        """
        # Self attention block. When the PE is skipped the attention output
        # replaces the queries outright (no residual); otherwise it is added.
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            queries_with_pe = queries + query_pe
            queries = queries + self.self_attn(q=queries_with_pe, k=queries_with_pe, v=queries)
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        token_side = queries + query_pe
        image_side = keys + key_pe
        queries = queries + self.cross_attn_token_to_image(q=token_side, k=image_side, v=keys)
        queries = self.norm2(queries)

        # MLP block
        queries = self.norm3(queries + self.mlp(queries))

        # Cross attention block, image embedding attending to tokens
        token_side = queries + query_pe
        image_side = keys + key_pe
        keys = keys + self.cross_attn_image_to_token(q=image_side, k=token_side, v=queries)
        keys = self.norm4(keys)

        return queries, keys
178
+
179
+
180
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        """
        Arguments:
          embedding_dim (int): channel size of the query input and of the output
          num_heads (int): number of attention heads; must divide the
            internal (downsampled) dimension
          downsample_rate (int): the internal dimension is
            ``embedding_dim // downsample_rate``
          dropout (float): attention dropout probability, applied only in
            training mode
          kv_in_dim (int or None): channel size of the key/value inputs;
            falls back to ``embedding_dim`` when None
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = embedding_dim if kv_in_dim is None else kv_in_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Splits the channel axis: B x N_tokens x C -> B x N_heads x N_tokens x C_per_head."""
        batch, n_tokens, channels = x.shape
        per_head = channels // num_heads
        return x.reshape(batch, n_tokens, num_heads, per_head).transpose(1, 2)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Inverse of ``_separate_heads``: B x N_heads x N_tokens x C_per_head -> B x N_tokens x C."""
        batch, n_heads, n_tokens, per_head = x.shape
        return x.transpose(1, 2).reshape(batch, n_tokens, n_heads * per_head)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Projects q/k/v, runs multi-head scaled dot-product attention and
        projects the result back to ``embedding_dim`` channels."""
        # Input projections, then split into heads
        q = self._separate_heads(self.q_proj(q), self.num_heads)
        k = self._separate_heads(self.k_proj(k), self.num_heads)
        v = self._separate_heads(self.v_proj(v), self.num_heads)

        # Dropout is disabled outside training mode.
        attn_dropout = self.dropout_p if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_dropout)

        # Merge heads and project back out
        return self.out_proj(self._recombine_heads(attended))
237
+
238
+
239
class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
        **kwargs,
    ):
        """Builds the base attention layer and precomputes the rotary table.

        ``*args`` / ``**kwargs`` are forwarded unchanged to ``Attention``.
        ``feat_sizes`` is passed to ``compute_axial_cis`` as
        ``(end_x, end_y)`` for the initial table; the table is rebuilt in
        ``forward`` whenever the query token count does not match it.
        """
        super().__init__(*args, **kwargs)

        per_head_dim = self.internal_dim // self.num_heads
        self.compute_cis = partial(compute_axial_cis, dim=per_head_dim, theta=rope_theta)

        initial_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        # Kept as a plain attribute (not a registered buffer); placed on CUDA
        # up front when available and moved to the query's device in forward.
        if torch.cuda.is_available():
            initial_cis = initial_cis.to("cuda")
        self.freqs_cis = initial_cis
        self.rope_k_repeat = rope_k_repeat

    def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
        """Same as ``Attention.forward`` but applies rotary encoding to q and k.

        The last ``num_k_exclude_rope`` key tokens are left without rotary
        encoding. If the query and key token counts differ, the layer must
        have been built with ``rope_k_repeat=True``.
        """
        # Input projections, then split into heads
        q = self._separate_heads(self.q_proj(q), self.num_heads)
        k = self._separate_heads(self.k_proj(k), self.num_heads)
        v = self._separate_heads(self.v_proj(v), self.num_heads)

        # Apply rotary position encoding
        num_q_tokens = q.shape[-2]
        # Side length of the token grid, which this treats as square
        # (math.sqrt yields a float, passed on as-is).
        grid_side = math.sqrt(num_q_tokens)
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != num_q_tokens:
            # Cached table does not fit this input: rebuild and cache it.
            self.freqs_cis = self.compute_cis(end_x=grid_side, end_y=grid_side).to(q.device)
        if num_q_tokens != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        rotated_q, rotated_k = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )
        q = rotated_q
        # Only the leading num_k_rope key tokens are overwritten, in place.
        k[:, :, :num_k_rope] = rotated_k

        # Dropout is disabled outside training mode.
        attn_dropout = self.dropout_p if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_dropout)

        # Merge heads and project back out
        return self.out_proj(self._recombine_heads(attended))