ultralytics 8.2.69__py3-none-any.whl → 8.2.71__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release: this version of ultralytics might be problematic.
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/__init__.py +4 -0
- ultralytics/data/converter.py +81 -0
- ultralytics/engine/trainer.py +3 -2
- ultralytics/engine/validator.py +2 -2
- ultralytics/models/__init__.py +2 -1
- ultralytics/models/fastsam/predict.py +1 -0
- ultralytics/models/sam/build.py +2 -2
- ultralytics/models/sam/model.py +10 -2
- ultralytics/models/sam/modules/decoders.py +1 -42
- ultralytics/models/sam/modules/encoders.py +3 -1
- ultralytics/models/sam/modules/sam.py +5 -7
- ultralytics/models/sam/modules/transformer.py +4 -3
- ultralytics/models/sam/predict.py +12 -6
- ultralytics/models/sam2/__init__.py +6 -0
- ultralytics/models/sam2/build.py +156 -0
- ultralytics/models/sam2/model.py +97 -0
- ultralytics/models/sam2/modules/__init__.py +1 -0
- ultralytics/models/sam2/modules/decoders.py +305 -0
- ultralytics/models/sam2/modules/encoders.py +332 -0
- ultralytics/models/sam2/modules/memory_attention.py +170 -0
- ultralytics/models/sam2/modules/sam2.py +804 -0
- ultralytics/models/sam2/modules/sam2_blocks.py +715 -0
- ultralytics/models/sam2/modules/utils.py +191 -0
- ultralytics/models/sam2/predict.py +182 -0
- ultralytics/nn/modules/transformer.py +5 -3
- ultralytics/utils/__init__.py +9 -9
- ultralytics/utils/plotting.py +1 -1
- ultralytics/utils/torch_utils.py +11 -7
- {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/METADATA +1 -1
- {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/RECORD +35 -24
- {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/LICENSE +0 -0
- {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/WHEEL +0 -0
- {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/top_level.txt +0 -0
ultralytics/models/sam2/modules/encoders.py
@@ -0,0 +1,332 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ultralytics.models.sam.modules.encoders import PatchEmbed
+
+from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine
+
+
+class MemoryEncoder(nn.Module):
+    """Encodes pixel features and masks into a memory representation for efficient image segmentation."""
+
+    def __init__(
+        self,
+        out_dim,
+        in_dim=256,  # in_dim of pix_feats
+    ):
+        """Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models."""
+        super().__init__()
+
+        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+
+        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+        self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
+        self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
+        self.out_proj = nn.Identity()
+        if out_dim != in_dim:
+            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+    def forward(
+        self,
+        pix_feat: torch.Tensor,
+        masks: torch.Tensor,
+        skip_mask_sigmoid: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Processes pixel features and masks, fusing them to generate encoded memory representations."""
+        if not skip_mask_sigmoid:
+            masks = F.sigmoid(masks)
+        masks = self.mask_downsampler(masks)
+
+        # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
+        pix_feat = pix_feat.to(masks.device)
+
+        x = self.pix_feat_proj(pix_feat)
+        x = x + masks
+        x = self.fuser(x)
+        x = self.out_proj(x)
+
+        pos = self.position_encoding(x).to(x.dtype)
+
+        return {"vision_features": x, "vision_pos_enc": [pos]}
+
+
+class ImageEncoder(nn.Module):
+    """Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings."""
+
+    def __init__(
+        self,
+        trunk: nn.Module,
+        neck: nn.Module,
+        scalp: int = 0,
+    ):
+        """Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction."""
+        super().__init__()
+        self.trunk = trunk
+        self.neck = neck
+        self.scalp = scalp
+        assert (
+            self.trunk.channel_list == self.neck.backbone_channel_list
+        ), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
+
+    def forward(self, sample: torch.Tensor):
+        """Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs."""
+        features, pos = self.neck(self.trunk(sample))
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+        src = features[-1]
+        output = {
+            "vision_features": src,
+            "vision_pos_enc": pos,
+            "backbone_fpn": features,
+        }
+        return output
+
+
+class FpnNeck(nn.Module):
+    """Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models."""
+
+    def __init__(
+        self,
+        d_model: int,
+        backbone_channel_list: List[int],
+        kernel_size: int = 1,
+        stride: int = 1,
+        padding: int = 0,
+        fpn_interp_model: str = "bilinear",
+        fuse_type: str = "sum",
+        fpn_top_down_levels: Optional[List[int]] = None,
+    ):
+        """
+        Initializes a modified Feature Pyramid Network (FPN) neck.
+
+        This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+        similar to ViT positional embedding interpolation.
+
+        Args:
+            d_model (int): Dimension of the model.
+            backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+            kernel_size (int): Kernel size for the convolutional layers.
+            stride (int): Stride for the convolutional layers.
+            padding (int): Padding for the convolutional layers.
+            fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+            fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+            fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
+
+        Attributes:
+            position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding.
+            convs (nn.ModuleList): List of convolutional layers for each backbone level.
+            backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+            fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+            fuse_type (str): Type of feature fusion.
+            fpn_top_down_levels (List[int]): Levels with top-down feature propagation.
+
+        Examples:
+            >>> backbone_channels = [64, 128, 256, 512]
+            >>> fpn_neck = FpnNeck(256, backbone_channels)
+            >>> print(fpn_neck)
+        """
+        super().__init__()
+        self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
+        self.convs = nn.ModuleList()
+        self.backbone_channel_list = backbone_channel_list
+        for dim in backbone_channel_list:
+            current = nn.Sequential()
+            current.add_module(
+                "conv",
+                nn.Conv2d(
+                    in_channels=dim,
+                    out_channels=d_model,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                ),
+            )
+
+            self.convs.append(current)
+        self.fpn_interp_model = fpn_interp_model
+        assert fuse_type in ["sum", "avg"]
+        self.fuse_type = fuse_type
+
+        # levels to have top-down features in its outputs
+        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+        # have top-down propagation, while outputs of level 0 and level 1 have only
+        # lateral features from the same backbone level.
+        if fpn_top_down_levels is None:
+            # default is to have top-down features on all levels
+            fpn_top_down_levels = range(len(self.convs))
+        self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+    def forward(self, xs: List[torch.Tensor]):
+        """
+        Performs forward pass through the Feature Pyramid Network (FPN) neck.
+
+        Args:
+            xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor.
+
+        Returns:
+            (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists:
+                - out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor.
+                - pos: List of positional encodings corresponding to each output feature map.
+
+        Examples:
+            >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
+            >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
+            >>> outputs, positions = fpn_neck(inputs)
+        """
+        out = [None] * len(self.convs)
+        pos = [None] * len(self.convs)
+        assert len(xs) == len(self.convs)
+        # fpn forward pass
+        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+        prev_features = None
+        # forward in top-down order (from low to high resolution)
+        n = len(self.convs) - 1
+        for i in range(n, -1, -1):
+            x = xs[i]
+            lateral_features = self.convs[n - i](x)
+            if i in self.fpn_top_down_levels and prev_features is not None:
+                top_down_features = F.interpolate(
+                    prev_features.to(dtype=torch.float32),
+                    scale_factor=2.0,
+                    mode=self.fpn_interp_model,
+                    align_corners=(None if self.fpn_interp_model == "nearest" else False),
+                    antialias=False,
+                )
+                prev_features = lateral_features + top_down_features
+                if self.fuse_type == "avg":
+                    prev_features /= 2
+            else:
+                prev_features = lateral_features
+            x_out = prev_features
+            out[i] = x_out
+            pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+        return out, pos
+
+
+class Hiera(nn.Module):
+    """Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks."""
+
+    def __init__(
+        self,
+        embed_dim: int = 96,  # initial embed dim
+        num_heads: int = 1,  # initial number of heads
+        drop_path_rate: float = 0.0,  # stochastic depth
+        q_pool: int = 3,  # number of q_pool stages
+        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
+        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
+        dim_mul: float = 2.0,  # dim_mul factor at stage shift
+        head_mul: float = 2.0,  # head_mul factor at stage shift
+        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+        # window size per stage, when not using global att.
+        window_spec: Tuple[int, ...] = (
+            8,
+            4,
+            14,
+            7,
+        ),
+        # global attn in these blocks
+        global_att_blocks: Tuple[int, ...] = (
+            12,
+            16,
+            20,
+        ),
+        return_interm_layers=True,  # return feats from every stage
+    ):
+        """Initializes a Hiera model with configurable architecture for hierarchical vision transformers."""
+        super().__init__()
+
+        assert len(stages) == len(window_spec)
+        self.window_spec = window_spec
+
+        depth = sum(stages)
+        self.q_stride = q_stride
+        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+        assert 0 <= q_pool <= len(self.stage_ends[:-1])
+        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+        self.return_interm_layers = return_interm_layers
+
+        self.patch_embed = PatchEmbed(
+            embed_dim=embed_dim,
+            kernel_size=(7, 7),
+            stride=(4, 4),
+            padding=(3, 3),
+        )
+        # Which blocks have global att?
+        self.global_att_blocks = global_att_blocks
+
+        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
+        self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        cur_stage = 1
+        self.blocks = nn.ModuleList()
+
+        for i in range(depth):
+            dim_out = embed_dim
+            # lags by a block, so first block of
+            # next stage uses an initial window size
+            # of previous stage and final window size of current stage
+            window_size = self.window_spec[cur_stage - 1]
+
+            if self.global_att_blocks is not None:
+                window_size = 0 if i in self.global_att_blocks else window_size
+
+            if i - 1 in self.stage_ends:
+                dim_out = int(embed_dim * dim_mul)
+                num_heads = int(num_heads * head_mul)
+                cur_stage += 1
+
+            block = MultiScaleBlock(
+                dim=embed_dim,
+                dim_out=dim_out,
+                num_heads=num_heads,
+                drop_path=dpr[i],
+                q_stride=self.q_stride if i in self.q_pool_blocks else None,
+                window_size=window_size,
+            )
+
+            embed_dim = dim_out
+            self.blocks.append(block)
+
+        self.channel_list = (
+            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+            if return_interm_layers
+            else [self.blocks[-1].dim_out]
+        )
+
+    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional embeddings by interpolating and combining window and background embeddings."""
+        h, w = hw
+        window_embed = self.pos_embed_window
+        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+        pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+        pos_embed = pos_embed.permute(0, 2, 3, 1)
+        return pos_embed
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        """Performs hierarchical vision transformer forward pass, returning multiscale feature maps."""
+        x = self.patch_embed(x)
+        # x: (B, H, W, C)
+
+        # Add pos embed
+        x = x + self._get_pos_embed(x.shape[1:3])
+
+        outputs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
+                feats = x.permute(0, 3, 1, 2)
+                outputs.append(feats)
+
+        return outputs
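The new encoders.py composes a Hiera trunk with an FpnNeck inside ImageEncoder, and ImageEncoder asserts that `trunk.channel_list` matches `neck.backbone_channel_list`. Below is a minimal construction-and-forward sketch of that wiring. It assumes this release (8.2.70+) is installed so the `sam2_blocks` dependencies (`MultiScaleBlock`, `PositionEmbeddingSine`, etc.) are available; the 256×256 input size is an illustrative assumption, not the packaged SAM 2 configuration, which is assembled in ultralytics/models/sam2/build.py.

```python
# Minimal wiring sketch for the new SAM2 image encoder (assumes ultralytics>=8.2.70 is
# installed; the official model assembly lives in ultralytics/models/sam2/build.py).
import torch

from ultralytics.models.sam2.modules.encoders import FpnNeck, Hiera, ImageEncoder

# With the defaults above (embed_dim=96, stages=(2, 3, 16, 3)), the embedding dim doubles
# at each of the three stage shifts, so channel_list is [768, 384, 192, 96]
# (lowest-resolution stage first).
trunk = Hiera(embed_dim=96, num_heads=1)
print(trunk.channel_list)

# ImageEncoder asserts trunk.channel_list == neck.backbone_channel_list, so reuse it directly.
neck = FpnNeck(d_model=256, backbone_channel_list=trunk.channel_list)

# scalp=1 discards the lowest-resolution FPN level from the outputs.
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)

# The 256x256 input is an illustrative assumption (SAM 2 checkpoints are built for 1024x1024).
with torch.no_grad():
    out = encoder(torch.randn(1, 3, 256, 256))
print(sorted(out))                   # ['backbone_fpn', 'vision_features', 'vision_pos_enc']
print(out["vision_features"].shape)  # expected: torch.Size([1, 256, 16, 16]) for this input
```

After the optional scalp trim, "vision_features" holds the last remaining FPN level, with "backbone_fpn" and "vision_pos_enc" carrying the full multiscale feature and positional-encoding lists.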
ultralytics/models/sam2/modules/memory_attention.py
@@ -0,0 +1,170 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import copy
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+
+from .sam2_blocks import RoPEAttention
+
+
+class MemoryAttentionLayer(nn.Module):
+    """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks."""
+
+    def __init__(
+        self,
+        d_model: int = 256,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        pos_enc_at_attn: bool = False,
+        pos_enc_at_cross_attn_keys: bool = True,
+        pos_enc_at_cross_attn_queries: bool = False,
+    ):
+        """Initializes a MemoryAttentionLayer with self-attention, cross-attention, and feedforward components."""
+        super().__init__()
+        self.d_model = d_model
+        self.dim_feedforward = dim_feedforward
+        self.dropout_value = dropout
+        self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+        self.cross_attn_image = RoPEAttention(
+            rope_k_repeat=True,
+            embedding_dim=256,
+            num_heads=1,
+            downsample_rate=1,
+            kv_in_dim=64,
+        )
+
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = nn.ReLU()
+
+        # Where to add pos enc
+        self.pos_enc_at_attn = pos_enc_at_attn
+        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+    def _forward_sa(self, tgt, query_pos):
+        """Performs self-attention on input tensor using positional encoding and RoPE attention mechanism."""
+        tgt2 = self.norm1(tgt)
+        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+        tgt2 = self.self_attn(q, k, v=tgt2)
+        tgt = tgt + self.dropout1(tgt2)
+        return tgt
+
+    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+        """Performs cross-attention between target and memory tensors using RoPEAttention mechanism."""
+        kwds = {}
+        if num_k_exclude_rope > 0:
+            assert isinstance(self.cross_attn_image, RoPEAttention)
+            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+
+        # Cross-Attention
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.cross_attn_image(
+            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+            v=memory,
+            **kwds,
+        )
+        tgt = tgt + self.dropout2(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+        num_k_exclude_rope: int = 0,
+    ) -> torch.Tensor:
+        """Performs self-attention, cross-attention, and MLP operations on input tensors for memory-based attention."""
+        tgt = self._forward_sa(tgt, query_pos)
+        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+        # MLP
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+
+class MemoryAttention(nn.Module):
+    """Memory attention module for processing sequential data with self and cross-attention mechanisms."""
+
+    def __init__(
+        self,
+        d_model: int,
+        pos_enc_at_input: bool,
+        layer: nn.Module,
+        num_layers: int,
+        batch_first: bool = True,  # Do layers expect batch first input?
+    ):
+        """Initializes MemoryAttention module with layers and normalization for attention processing."""
+        super().__init__()
+        self.d_model = d_model
+        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+        self.num_layers = num_layers
+        self.norm = nn.LayerNorm(d_model)
+        self.pos_enc_at_input = pos_enc_at_input
+        self.batch_first = batch_first
+
+    def forward(
+        self,
+        curr: torch.Tensor,  # self-attention inputs
+        memory: torch.Tensor,  # cross-attention inputs
+        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
+        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
+        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
+    ):
+        """Applies self-attention and cross-attention to input tensors, processing through multiple layers."""
+        if isinstance(curr, list):
+            assert isinstance(curr_pos, list)
+            assert len(curr) == len(curr_pos) == 1
+            curr, curr_pos = (
+                curr[0],
+                curr_pos[0],
+            )
+
+        assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"
+
+        output = curr
+        if self.pos_enc_at_input and curr_pos is not None:
+            output = output + 0.1 * curr_pos
+
+        if self.batch_first:
+            # Convert to batch first
+            output = output.transpose(0, 1)
+            curr_pos = curr_pos.transpose(0, 1)
+            memory = memory.transpose(0, 1)
+            memory_pos = memory_pos.transpose(0, 1)
+
+        for layer in self.layers:
+            kwds = {}
+            if isinstance(layer.cross_attn_image, RoPEAttention):
+                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+
+            output = layer(
+                tgt=output,
+                memory=memory,
+                pos=memory_pos,
+                query_pos=curr_pos,
+                **kwds,
+            )
+        normed_output = self.norm(output)
+
+        if self.batch_first:
+            # Convert back to seq first
+            normed_output = normed_output.transpose(0, 1)
+            curr_pos = curr_pos.transpose(0, 1)
+
+        return normed_output
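MemoryAttentionLayer pairs RoPE self-attention over the current frame's 256-dim tokens with cross-attention whose keys and values come from 64-dim memory features (kv_in_dim=64), followed by a feed-forward block; MemoryAttention deep-copies that layer num_layers times and applies a final LayerNorm. A construction-only sketch is shown below; it assumes this release is installed, and the num_layers=4 choice plus the shape notes in the comments are illustrative assumptions rather than values taken from the packaged build (ultralytics/models/sam2/build.py).

```python
# Construction-only sketch (assumption: ultralytics>=8.2.70 provides these modules).
from ultralytics.models.sam2.modules.memory_attention import MemoryAttention, MemoryAttentionLayer

# One layer: RoPE self-attention over 256-dim queries, plus RoPE cross-attention whose
# keys/values come from 64-dim memory features (kv_in_dim=64), plus a 2048-dim MLP.
layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)

# The stack deep-copies the layer num_layers times and LayerNorms the final output.
memory_attention = MemoryAttention(
    d_model=256,
    pos_enc_at_input=True,  # adds 0.1 * curr_pos to the inputs before the layers
    layer=layer,
    num_layers=4,           # illustrative assumption, not taken from build.py
    batch_first=True,       # inputs are (seq, batch, dim); transposed to batch-first for the layers
)

# Expected forward signature (shapes are assumptions for illustration):
#   curr:       (num_tokens, batch, 256)      current-frame features
#   memory:     (num_mem_tokens, batch, 64)   encoded memory features
#   curr_pos / memory_pos: positional encodings matching curr / memory
#   num_obj_ptr_tokens: trailing memory tokens excluded from rotary encoding (num_k_exclude_rope)
n_params = sum(p.numel() for p in memory_attention.parameters())
print(f"MemoryAttention: {memory_attention.num_layers} layers, {n_params:,} parameters")
```

The 64-dim memory inputs line up with the kv_in_dim=64 cross-attention above; how they are produced and fed in at runtime is handled elsewhere in the sam2 package rather than in this module.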