singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,202 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+
14
+ from sam2.modeling.sam2_utils import LayerNorm2d
15
+
16
+
17
class PromptEncoder(nn.Module):
    """
    Turns point, box and mask prompts into the sparse and dense embeddings
    consumed by SAM's mask decoder.
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension.
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as
            input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (type[nn.Module]): The activation class instantiated
            between the mask-downscaling convolutions.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # One learned embedding per prompt label. Indices 0/1 are the two point
        # labels ("pos/neg point"); indices 2/3 are the two box corners
        # (see _embed_boxes).
        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [
            nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
        ]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        # Replaces the positional encoding of points labelled -1 (padding).
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # 4x the embedding resolution, matching the 4x reduction performed by
        # the two stride-2 convolutions below.
        self.mask_input_size = (
            4 * image_embedding_size[0],
            4 * image_embedding_size[1],
        )
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Dense embedding broadcast over the whole grid when no mask is given.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """
        Embeds point prompts.

        Arguments:
          points (torch.Tensor): point coordinates, B x N x 2.
          labels (torch.Tensor): per-point labels, B x N. -1 marks a padding
            point; 0..num_point_embeddings-1 select a learned type embedding.
          pad (bool): if True, append one padding point (label -1) per batch
            element.

        Returns:
          torch.Tensor: B x N' x embed_dim point embeddings (N' = N + pad).
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )

        # Padding points discard their positional encoding entirely and take
        # the learned "not a point" embedding instead.
        point_embedding = torch.where(
            (labels == -1).unsqueeze(-1),
            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
            point_embedding,
        )
        # Every other known label adds its learned type embedding on top of the
        # positional encoding (same order as indices 0, 1, 2, 3).
        for label, type_embedding in enumerate(self.point_embeddings):
            point_embedding = torch.where(
                (labels == label).unsqueeze(-1),
                point_embedding + type_embedding.weight,
                point_embedding,
            )
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Embeds box prompts as two corner points.

        The boxes are reshaped to (-1, 2, 2), i.e. 4 coordinates per box,
        presumably (x1, y1, x2, y2). The two corners get point_embeddings[2]
        and point_embeddings[3] respectively.
        """
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds single-channel mask inputs via the mask-downscaling stack."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input
        prompts (points take precedence, then boxes, then masks; 1 if none).
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        """Device of the learned point embeddings."""
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or None): point coordinates
            and labels to embed.
          boxes (torch.Tensor or None): boxes to embed
          masks (torch.Tensor or None): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty(
            (bs, 0, self.embed_dim), device=self._get_device()
        )
        if points is not None:
            coords, labels = points
            # A padding point is only appended when no box follows the points.
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings
@@ -0,0 +1,311 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from functools import partial
9
+ from typing import Tuple, Type
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn, Tensor
14
+
15
+ from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
16
+ from sam2.modeling.sam2_utils import MLP
17
+
18
+
19
class TwoWayTransformer(nn.Module):
    """
    Stack of TwoWayAttentionBlocks followed by one last token-to-image
    attention, as used by the SAM mask decoder.
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of TwoWayAttentionBlocks to stack.
          embedding_dim (int): channel dimension of the input embeddings.
          num_heads (int): number of heads for multihead attention; must
            divide the (possibly downsampled) attention dimension.
          mlp_dim (int): hidden channel dimension of each block's MLP.
          activation (nn.Module): activation class used in the MLP blocks.
          attention_downsample_rate (int): downsample rate passed to the
            cross-attention layers.
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        # Blocks are built in index order; only block 0 skips the query PE in
        # its self-attention.
        self.layers = nn.ModuleList(
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                attention_downsample_rate=attention_downsample_rate,
                skip_first_layer_pe=(layer_idx == 0),
            )
            for layer_idx in range(depth)
        )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to, shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): positional encoding for the image; must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): query-point embeddings, shape
            B x N_points x embedding_dim for any N_points. Also used as the
            query positional encoding at every layer.

        Returns:
          torch.Tensor: the processed point_embedding (queries).
          torch.Tensor: the processed image_embedding (keys).
        """
        # The 4-way unpacking also rejects inputs that are not 4-D.
        batch, channels, height, width = image_embedding.shape
        # B x C x H x W  ->  B x (H*W) x C
        keys = image_embedding.flatten(2).permute(0, 2, 1)
        key_pe = image_pe.flatten(2).permute(0, 2, 1)

        queries = point_embedding
        for block in self.layers:
            queries, keys = block(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=key_pe,
            )

        # Last token -> image attention; PEs go on q and k but not on v.
        attn_out = self.final_attn_token_to_image(
            q=queries + point_embedding, k=keys + key_pe, v=keys
        )
        queries = self.norm_final_attn(queries + attn_out)

        return queries, keys
110
+
111
+
112
class TwoWayAttentionBlock(nn.Module):
    """
    One decoder block: sparse self-attention, token->image cross-attention,
    MLP on the tokens, then image->token cross-attention. Each sub-layer is a
    residual add followed by LayerNorm.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to
        sparse inputs.

        Arguments:
          embedding_dim (int): channel dimension of the embeddings.
          num_heads (int): number of heads in the attention layers.
          mlp_dim (int): hidden dimension of the MLP block.
          activation (nn.Module): activation class of the MLP block.
          attention_downsample_rate (int): downsample rate of the two
            cross-attention layers (self-attention is not downsampled).
          skip_first_layer_pe (bool): if True, self-attention runs without
            the query PE and without a residual connection.
        """
        super().__init__()
        rate = attention_downsample_rate

        # Sub-modules are created in this exact order; it fixes both the
        # state_dict key order and the order of random parameter init.

        # (1) self-attention over the sparse tokens
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # (2) sparse tokens attend to the dense image embedding
        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # (3) MLP on the sparse tokens
        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        # (4) dense image embedding attends to the sparse tokens
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # (1) Self-attention.
        if self.skip_first_layer_pe:
            # No PE and no residual: the attention output replaces the queries.
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            queries_pe = queries + query_pe
            queries = queries + self.self_attn(q=queries_pe, k=queries_pe, v=queries)
        queries = self.norm1(queries)

        # (2) Tokens attend to the image; PEs on q/k only.
        queries_pe = queries + query_pe
        keys_pe = keys + key_pe
        queries = self.norm2(
            queries + self.cross_attn_token_to_image(q=queries_pe, k=keys_pe, v=keys)
        )

        # (3) MLP on the tokens.
        queries = self.norm3(queries + self.mlp(queries))

        # (4) Image attends to the (updated) tokens; roles of q/k are swapped.
        queries_pe = queries + query_pe
        keys_pe = keys + key_pe
        keys = self.norm4(
            keys + self.cross_attn_image_to_token(q=keys_pe, k=queries_pe, v=queries)
        )

        return queries, keys
188
+
189
+
190
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.

    Queries are projected from ``embedding_dim`` and keys/values from
    ``kv_in_dim`` into ``internal_dim = embedding_dim // downsample_rate``
    channels. These are split into ``num_heads`` heads and combined with
    scaled dot-product attention. The result is projected back to
    ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        """
        Arguments:
          embedding_dim (int): channel dimension of the query input and of
            the output.
          num_heads (int): number of attention heads; must divide
            ``embedding_dim // downsample_rate``.
          downsample_rate (int): integer divisor applied to ``embedding_dim``
            to get the projected (internal) dimension.
          dropout (float): attention dropout probability, used only while
            the module is in training mode.
          kv_in_dim (int or None): channel dimension of the key/value inputs;
            None means ``embedding_dim``.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        # The heads split internal_dim, not embedding_dim, so report that.
        assert self.internal_dim % num_heads == 0, (
            f"num_heads ({num_heads}) must divide the internal dimension "
            f"({self.internal_dim} = embedding_dim {embedding_dim} // "
            f"downsample_rate {downsample_rate})."
        )

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """B x N_tokens x C  ->  B x N_heads x N_tokens x C_per_head."""
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """B x N_heads x N_tokens x C_per_head  ->  B x N_tokens x C."""
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """
        Arguments:
          q (Tensor): B x N_q x embedding_dim queries.
          k (Tensor): B x N_k x kv_in_dim keys.
          v (Tensor): B x N_k x kv_in_dim values.

        Returns:
          Tensor: B x N_q x embedding_dim attention output.
        """
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Dropout is disabled outside training mode.
        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
249
+
250
+
251
class RoPEAttention(Attention):
    """Attention with rotary position encoding.

    Uses the same projections and head split as ``Attention``. Before
    attention, ``apply_rotary_enc`` is applied to q and to the first
    ``k.size(-2) - num_k_exclude_rope`` key tokens.
    """

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
        **kwargs,
    ):
        """
        Arguments:
          *args, **kwargs: forwarded unchanged to ``Attention.__init__``.
          rope_theta (float): ``theta`` passed to ``compute_axial_cis``.
          rope_k_repeat (bool): passed to ``apply_rotary_enc`` as
            ``repeat_freqs_k``. forward asserts it is True whenever q and k
            have different token counts.
          feat_sizes (tuple): (end_x, end_y) used to precompute the initial
            frequency table.
        """
        super().__init__(*args, **kwargs)

        # Frequency-table factory bound to the per-head channel count; reused
        # by forward when the query token count changes.
        self.compute_cis = partial(
            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
        )
        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        # NOTE(review): freqs_cis is a plain attribute, not a registered
        # buffer, so module.to()/cuda() will not move it. It is eagerly placed
        # on CUDA whenever CUDA is available, and forward re-syncs it to
        # q.device.
        self.freqs_cis = (
            freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
        )
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        """
        Arguments:
          q, k, v (Tensor): query/key/value inputs; passed through the q/k/v
            projections and split into heads as in ``Attention``.
          num_k_exclude_rope (int): number of trailing key tokens that do NOT
            receive the rotary encoding.

        Returns:
          Tensor: attention output after ``out_proj``.
        """
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Apply rotary position encoding
        # NOTE(review): treats the query tokens as a square grid
        # (w == h == sqrt(N_q)); w/h are floats here — confirm callers only
        # pass square feature maps.
        w = h = math.sqrt(q.shape[-2])
        # Keep the cached table on q's device (cached on the instance).
        self.freqs_cis = self.freqs_cis.to(q.device)
        # Rebuild and cache the table when its length no longer matches the
        # query token count.
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        # Only the leading num_k_rope key tokens are rotated; that slice of k
        # is overwritten in place, the remaining tokens are left untouched.
        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k[:, :, :num_k_rope] = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )

        # Dropout is disabled outside training mode.
        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out