singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from typing import Optional, Tuple, Type
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from sam2.modeling.position_encoding import PositionEmbeddingRandom
|
|
13
|
+
|
|
14
|
+
from sam2.modeling.sam2_utils import LayerNorm2d
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class PromptEncoder(nn.Module):
    """Turns point, box and mask prompts into embeddings for SAM's mask decoder.

    Points and boxes yield *sparse* embeddings (one token per point or box
    corner); masks yield a *dense* embedding laid out on the image-embedding
    grid.
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension.
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as
            input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (type[nn.Module]): The activation class to use when
            encoding input masks; a fresh instance is created per use.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # One learned embedding per known label: entries 0/1 are added to
        # points labelled 0/1, entries 2/3 to the first/second box corner
        # (and to points labelled 2/3).
        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        self.point_embeddings = nn.ModuleList(
            [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        )
        # Replaces the positional encoding of padding points (label -1).
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # mask_input_size is 4x the embedding grid; the two kernel-2/stride-2
        # convolutions below reduce the spatial size by that same factor of 4.
        self.mask_input_size = (
            4 * image_embedding_size[0],
            4 * image_embedding_size[1],
        )
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Dense embedding broadcast over the grid when no mask is supplied.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts.

        Args:
          points (torch.Tensor): point coordinates, B x N x 2, in the pixel
            frame of ``input_image_size``.
          labels (torch.Tensor): per-point labels, B x N. Label -1 marks a
            padding point; labels 0..num_point_embeddings-1 add the matching
            learned embedding; any other value keeps the bare positional
            encoding.
          pad (bool): if True, append one padding point (label -1) per row.

        Returns:
          torch.Tensor: point embeddings, one per (possibly padded) point.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )

        # Padding points carry no position: swap their PE for the learned
        # "not a point" embedding.
        point_embedding = torch.where(
            (labels == -1).unsqueeze(-1),
            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
            point_embedding,
        )
        # Add the learned embedding for each known label; derived from the
        # ModuleList so it cannot drift from num_point_embeddings.
        for label_id, label_embed in enumerate(self.point_embeddings):
            point_embedding = torch.where(
                (labels == label_id).unsqueeze(-1),
                point_embedding + label_embed.weight,
                point_embedding,
            )
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts.

        Each box is reshaped into two corner points (B x 2 x 2); the first
        corner gets ``point_embeddings[2]`` and the second
        ``point_embeddings[3]`` added to its positional encoding.
        """
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs by running them through ``mask_downscaling``."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input
        prompts. Points take precedence over boxes, boxes over masks; with no
        prompts at all the batch size is 1.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        """Device of the module's parameters (read from the first point embedding)."""
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or None): point coordinates
            and labels to embed. When no boxes are given, one padding point is
            appended to each row.
          boxes (torch.Tensor or None): boxes to embed
          masks (torch.Tensor or None): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty(
            (bs, 0, self.embed_dim), device=self._get_device()
        )
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            # No mask: broadcast the learned no-mask embedding over the grid.
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings
|
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import math
from functools import partial
from typing import Optional, Tuple, Type

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TwoWayTransformer(nn.Module):
    """Two-way transformer decoder used by SAM's mask decoder.

    Runs ``depth`` ``TwoWayAttentionBlock``s that update both the prompt
    tokens (queries) and the flattened image tokens (keys), followed by one
    last token-to-image attention and a layer norm on the queries.
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of two-way attention blocks.
          embedding_dim (int): channel dimension of the input embeddings.
          num_heads (int): number of heads for multihead attention; must
            divide the (downsampled) attention width.
          mlp_dim (int): hidden channel dimension of each block's MLP.
          activation (type[nn.Module]): activation class used in the MLPs.
          attention_downsample_rate (int): internal-width reduction factor for
            the token<->image cross attentions.
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList(
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                attention_downsample_rate=attention_downsample_rate,
                # Only the very first block skips the query PE in self-attention.
                skip_first_layer_pe=(layer_idx == 0),
            )
            for layer_idx in range(depth)
        )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to, B x C x h x w
            (must be 4-D).
          image_pe (torch.Tensor): positional encoding for the image, same
            shape as ``image_embedding``.
          point_embedding (torch.Tensor): query tokens, B x N_points x C.
            Also reused as the query positional encoding in every block.

        Returns:
          torch.Tensor: the processed point_embedding (queries).
          torch.Tensor: the processed image_embedding, B x (h*w) x C (keys).
        """
        # Unpacking enforces a 4-D image embedding; the sizes are not needed.
        _bs, _c, _h, _w = image_embedding.shape
        # B x C x H x W  ->  B x (H*W) x C
        image_tokens = image_embedding.flatten(2).permute(0, 2, 1)
        image_pos = image_pe.flatten(2).permute(0, 2, 1)

        queries, keys = point_embedding, image_tokens
        for block in self.layers:
            queries, keys = block(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pos,
            )

        # Final tokens -> image attention with residual, then layer norm.
        attn_out = self.final_attn_token_to_image(
            q=queries + point_embedding, k=keys + image_pos, v=keys
        )
        queries = self.norm_final_attn(queries + attn_out)
        return queries, keys
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class TwoWayAttentionBlock(nn.Module):
    """One block of the two-way transformer.

    Sub-steps, each followed by a layer norm:
      1. self-attention among the sparse tokens (queries),
      2. cross attention from tokens to the dense image tokens (keys),
      3. an MLP on the tokens,
      4. cross attention from the image tokens back to the tokens.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        Arguments:
          embedding_dim (int): channel dimension of the embeddings.
          num_heads (int): number of heads in the attention layers.
          mlp_dim (int): hidden dimension of the MLP.
          activation (type[nn.Module]): activation class of the MLP.
          attention_downsample_rate (int): internal-width reduction factor of
            the two cross attentions (self-attention is not downsampled).
          skip_first_layer_pe (bool): if True, self-attention uses neither the
            query positional encoding nor a residual connection.
        """
        super().__init__()
        # Creation order kept as-is: it fixes parameter-init RNG consumption
        # and the state_dict layout.
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Run the four sub-steps and return the updated (queries, keys)."""
        # (1) Token self-attention.
        if not self.skip_first_layer_pe:
            tokens_pe = queries + query_pe
            queries = queries + self.self_attn(q=tokens_pe, k=tokens_pe, v=queries)
        else:
            # No PE and no residual on this branch: the output replaces queries.
            queries = self.self_attn(q=queries, k=queries, v=queries)
        queries = self.norm1(queries)

        # keys are not modified until step 4, so this sum serves both
        # cross attentions.
        image_pe_keys = keys + key_pe

        # (2) Tokens attend to the image.
        token_update = self.cross_attn_token_to_image(
            q=queries + query_pe, k=image_pe_keys, v=keys
        )
        queries = self.norm2(queries + token_update)

        # (3) Token MLP.
        queries = self.norm3(queries + self.mlp(queries))

        # (4) Image tokens attend back to the updated tokens.
        image_update = self.cross_attn_image_to_token(
            q=image_pe_keys, k=queries + query_pe, v=queries
        )
        keys = self.norm4(keys + image_update)

        return queries, keys
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.

    Queries are projected from ``embedding_dim`` and keys/values from
    ``kv_in_dim`` into an internal width of ``embedding_dim // downsample_rate``,
    split across ``num_heads`` heads, combined with scaled dot-product
    attention, and projected back to ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: Optional[int] = None,
    ) -> None:
        """
        Args:
          embedding_dim (int): channel dimension of the query input and of
            the output.
          num_heads (int): number of attention heads; must divide
            ``embedding_dim // downsample_rate``.
          downsample_rate (int): factor by which the internal q/k/v width is
            reduced relative to ``embedding_dim``.
          dropout (float): attention dropout probability, applied only while
            the module is in training mode.
          kv_in_dim (int, optional): channel dimension of the key/value
            inputs; defaults to ``embedding_dim``.

        Raises:
          AssertionError: if ``num_heads`` does not divide the internal width.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        # Heads split the *internal* (downsampled) width, so that is what must
        # divide evenly. The message reports the actual values to make a
        # misconfiguration easy to spot.
        assert self.internal_dim % num_heads == 0, (
            f"num_heads ({num_heads}) must divide embedding_dim // downsample_rate "
            f"({self.internal_dim})."
        )

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """B x N_tokens x C -> B x N_heads x N_tokens x C_per_head."""
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """B x N_heads x N_tokens x C_per_head -> B x N_tokens x C."""
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Attend from ``q`` to ``k``/``v``.

        Args:
          q (Tensor): queries, B x N_q x embedding_dim.
          k (Tensor): keys, B x N_k x kv_in_dim.
          v (Tensor): values, B x N_k x kv_in_dim.

        Returns:
          Tensor: B x N_q x embedding_dim.
        """
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Dropout is only active in training mode.
        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class RoPEAttention(Attention):
    """Attention with rotary position encoding.

    Rotary encoding (via ``apply_rotary_enc``) is applied to the queries and
    to all keys except the last ``num_k_exclude_rope`` of them.
    """

    def __init__(
        self,
        *args,
        rope_theta: float = 10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat: bool = False,
        feat_sizes: Tuple[int, int] = (64, 64),  # [w, h] for stride 16 feats at 1024 resolution
        **kwargs,
    ):
        """
        Args:
          *args, **kwargs: forwarded unchanged to ``Attention.__init__``.
          rope_theta (float): ``theta`` passed to ``compute_axial_cis``.
          rope_k_repeat (bool): allow keys to be longer than queries by
            repeating the query rotary table; without it, mismatched q/k
            lengths fail an assertion in ``forward``.
          feat_sizes (tuple(int, int)): (w, h) grid used to precompute the
            rotary table; it is rebuilt in ``forward`` if the query length
            does not match.
        """
        super().__init__(*args, **kwargs)

        self.compute_cis = partial(
            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
        )
        # The table is not moved to CUDA here: forward() already moves it to
        # the input's device on first use and caches it there. The previous
        # eager .to("cuda") initialized CUDA as a side effect of construction
        # whenever CUDA was available, even for modules used only on CPU.
        self.freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        """Attend from ``q`` to ``k``/``v`` with rotary encoding on q and k.

        Args:
          q (Tensor): queries, B x N_q x embedding_dim.
          k (Tensor): keys, B x N_k x kv_in_dim.
          v (Tensor): values, B x N_k x kv_in_dim.
          num_k_exclude_rope (int): number of trailing key tokens that are
            left without rotary encoding.

        Returns:
          Tensor: B x N_q x embedding_dim.
        """
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads: B x N_heads x N_tokens x C_per_head
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Apply rotary position encoding. When the cached table does not match
        # the query length, it is rebuilt assuming the queries form a square
        # grid (w == h == sqrt(N_q)).
        w = h = math.sqrt(q.shape[-2])
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        # Only the leading keys get RoPE; the slice is written back in place.
        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k[:, :, :num_k_rope] = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )

        # Dropout is only active in training mode.
        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
|