ultralytics 8.2.68__py3-none-any.whl → 8.2.70__py3-none-any.whl
This diff shows the content changes between two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Potentially problematic release: this version of ultralytics might be problematic.
- tests/test_cli.py +4 -16
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/__init__.py +4 -0
- ultralytics/data/augment.py +1 -1
- ultralytics/hub/google/__init__.py +3 -3
- ultralytics/models/__init__.py +2 -1
- ultralytics/models/fastsam/__init__.py +1 -2
- ultralytics/models/fastsam/model.py +18 -0
- ultralytics/models/fastsam/predict.py +116 -1
- ultralytics/models/sam/build.py +2 -2
- ultralytics/models/sam/model.py +10 -2
- ultralytics/models/sam/modules/decoders.py +1 -42
- ultralytics/models/sam/modules/encoders.py +3 -1
- ultralytics/models/sam/modules/sam.py +5 -7
- ultralytics/models/sam/modules/transformer.py +4 -3
- ultralytics/models/sam/predict.py +12 -6
- ultralytics/models/sam2/__init__.py +6 -0
- ultralytics/models/sam2/build.py +156 -0
- ultralytics/models/sam2/model.py +97 -0
- ultralytics/models/sam2/modules/__init__.py +1 -0
- ultralytics/models/sam2/modules/decoders.py +305 -0
- ultralytics/models/sam2/modules/encoders.py +332 -0
- ultralytics/models/sam2/modules/memory_attention.py +170 -0
- ultralytics/models/sam2/modules/sam2.py +804 -0
- ultralytics/models/sam2/modules/sam2_blocks.py +715 -0
- ultralytics/models/sam2/modules/utils.py +191 -0
- ultralytics/models/sam2/predict.py +182 -0
- ultralytics/nn/modules/transformer.py +5 -3
- ultralytics/utils/ops.py +1 -1
- ultralytics/utils/torch_utils.py +9 -6
- {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/METADATA +1 -1
- {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/RECORD +36 -26
- {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/WHEEL +1 -1
- ultralytics/models/fastsam/prompt.py +0 -352
- {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/LICENSE +0 -0
- {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/top_level.txt +0 -0
ultralytics/models/sam2/modules/sam2_blocks.py
@@ -0,0 +1,715 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import copy
import math
from functools import partial
from typing import Optional, Tuple, Type, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ultralytics.models.sam.modules.transformer import (
    Attention,
)
from ultralytics.models.sam.modules.transformer import (
    TwoWayAttentionBlock as SAMTwoWayAttentionBlock,
)
from ultralytics.models.sam.modules.transformer import (
    TwoWayTransformer as SAMTwoWayTransformer,
)
from ultralytics.nn.modules import MLP, LayerNorm2d

from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition


class DropPath(nn.Module):
    """Implements stochastic depth regularization for neural networks during training."""

    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        """Initialize DropPath module with specified drop probability and scaling option."""
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Applies stochastic depth to input tensor during training, with optional scaling."""
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor
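
# Illustrative usage sketch (not part of the packaged module): DropPath drops whole samples
# during training and rescales the survivors by 1 / keep_prob, so the output shape is unchanged.
#   >>> dp = DropPath(drop_prob=0.1).train()
#   >>> x = torch.randn(4, 64, 32, 32)
#   >>> y = dp(x)  # ~10% of samples zeroed, the rest scaled by 1 / 0.9
#   >>> assert y.shape == x.shape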


class MaskDownSampler(nn.Module):
    """Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing."""

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
    ):
        """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
        super().__init__()
        num_layers = int(math.log2(total_stride) // math.log2(stride))
        assert stride**num_layers == total_stride
        self.encoder = nn.Sequential()
        mask_in_chans, mask_out_chans = 1, 1
        for _ in range(num_layers):
            mask_out_chans = mask_in_chans * (stride**2)
            self.encoder.append(
                nn.Conv2d(
                    mask_in_chans,
                    mask_out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            self.encoder.append(LayerNorm2d(mask_out_chans))
            self.encoder.append(activation())
            mask_in_chans = mask_out_chans

        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))

    def forward(self, x):
        """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
        return self.encoder(x)
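
# Illustrative usage sketch (not part of the packaged module): with the defaults
# (stride=4, total_stride=16) a single-channel mask is downsampled 16x and embedded to 256 channels.
#   >>> sampler = MaskDownSampler()
#   >>> masks = torch.randn(2, 1, 256, 256)
#   >>> sampler(masks).shape
#   torch.Size([2, 256, 16, 16])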


# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
class CXBlock(nn.Module):
    """
    ConvNeXt Block for efficient feature extraction in convolutional neural networks.

    This block implements a modified version of the ConvNeXt architecture, offering two equivalent
    implementations for improved performance and flexibility.

    Attributes:
        dwconv (nn.Conv2d): Depthwise convolution layer.
        norm (LayerNorm2d): Layer normalization applied to channels.
        pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
        act (nn.GELU): GELU activation function.
        pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
        gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
        drop_path (nn.Module): DropPath layer for stochastic depth regularization.

    Methods:
        forward: Processes the input tensor through the ConvNeXt block.

    Examples:
        >>> import torch
        >>> x = torch.randn(1, 64, 56, 56)
        >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
        >>> output = block(x)
        >>> print(output.shape)
        torch.Size([1, 64, 56, 56])
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        """
        Initialize a ConvNeXt Block.

        This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization,
        pointwise convolutions, and GELU activation.

        Args:
            dim (int): Number of input channels.
            kernel_size (int): Size of the convolutional kernel. Default is 7.
            padding (int): Padding size for the convolution. Default is 3.
            drop_path (float): Stochastic depth rate. Default is 0.0.
            layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6.
            use_dwconv (bool): Whether to use depthwise convolution. Default is True.

        Attributes:
            dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
            norm (LayerNorm2d): Layer normalization applied to the output of dwconv.
            pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
            act (nn.GELU): GELU activation function.
            pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
            gamma (nn.Parameter | None): Learnable scale parameter for the residual path.

        Examples:
            >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
            >>> x = torch.randn(1, 64, 32, 32)
            >>> output = block(x)
            >>> print(output.shape)
            torch.Size([1, 64, 32, 32])
        """
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim if use_dwconv else 1,
        )  # depthwise conv
        self.norm = LayerNorm2d(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


class Fuser(nn.Module):
    """
    A module for fusing features through multiple layers of a neural network.

    This class applies a series of identical layers to an input tensor, optionally projecting the input first.

    Attributes:
        proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
        layers (nn.ModuleList): A list of identical layers to be applied sequentially.

    Methods:
        forward: Applies the fuser to an input tensor.

    Examples:
        >>> layer = CXBlock(dim=256)
        >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
        >>> x = torch.randn(1, 256, 32, 32)
        >>> output = fuser(x)
        >>> print(output.shape)
        torch.Size([1, 256, 32, 32])
    """

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        """
        Initializes the Fuser module.

        This module creates a sequence of identical layers and optionally applies an input projection.

        Args:
            layer (nn.Module): The layer to be replicated in the fuser.
            num_layers (int): The number of times to replicate the layer.
            dim (int | None): The dimension for input projection, if used.
            input_projection (bool): Whether to use input projection.

        Attributes:
            proj (nn.Module): The input projection layer, or nn.Identity if not used.
            layers (nn.ModuleList): A list of replicated layers.

        Examples:
            >>> layer = nn.Linear(64, 64)
            >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
            >>> input_tensor = torch.randn(1, 64)
            >>> output = fuser(input_tensor)
        """
        super().__init__()
        self.proj = nn.Identity()
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])

        if input_projection:
            assert dim is not None
            self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Applies a series of layers to the input tensor, optionally projecting it first."""
        x = self.proj(x)
        for layer in self.layers:
            x = layer(x)
        return x


class TwoWayAttentionBlock(SAMTwoWayAttentionBlock):
    """
    A two-way attention block for performing self-attention and cross-attention in both directions.

    This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on
    sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
    cross-attention from dense to sparse inputs.

    Attributes:
        self_attn (Attention): Self-attention layer for queries.
        norm1 (nn.LayerNorm): Layer normalization after the first attention block.
        cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
        norm2 (nn.LayerNorm): Layer normalization after the second attention block.
        mlp (MLP): MLP block for transforming query embeddings.
        norm3 (nn.LayerNorm): Layer normalization after the MLP block.
        norm4 (nn.LayerNorm): Layer normalization after the third attention block.
        cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
        skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.

    Methods:
        forward: Processes input through the attention blocks and MLP.

    Examples:
        >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
        >>> sparse_input = torch.randn(1, 100, 256)
        >>> dense_input = torch.randn(1, 256, 16, 16)
        >>> sparse_output, dense_output = block(sparse_input, dense_input)
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.

        This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs
        to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs.

        Args:
            embedding_dim (int): The channel dimension of the embeddings.
            num_heads (int): The number of heads in the attention layers.
            mlp_dim (int): The hidden dimension of the MLP block.
            activation (Type[nn.Module]): The activation function of the MLP block.
            attention_downsample_rate (int): The downsample rate for attention computations.
            skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.

        Attributes:
            self_attn (Attention): The self-attention layer for the queries.
            norm1 (nn.LayerNorm): Layer normalization following the first attention block.
            cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
            norm2 (nn.LayerNorm): Layer normalization following the second attention block.
            mlp (MLP): MLP block that transforms the query embeddings.
            norm3 (nn.LayerNorm): Layer normalization following the MLP block.
            norm4 (nn.LayerNorm): Layer normalization following the third attention block.
            cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
            skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.

        Examples:
            >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
            >>> sparse_inputs = torch.randn(1, 100, 256)
            >>> dense_inputs = torch.randn(1, 256, 32, 32)
            >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
        """
        super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
        self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)


class TwoWayTransformer(SAMTwoWayTransformer):
    """
    A Two-Way Transformer module for simultaneous attention to image and query points.

    This class implements a specialized transformer decoder that attends to an input image using queries with
    supplied positional embeddings. It is particularly useful for tasks like object detection, image
    segmentation, and point cloud processing.

    Attributes:
        depth (int): Number of layers in the transformer.
        embedding_dim (int): Channel dimension for input embeddings.
        num_heads (int): Number of heads for multihead attention.
        mlp_dim (int): Internal channel dimension for the MLP block.
        layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
        final_attn_token_to_image (Attention): Final attention layer from queries to image.
        norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.

    Methods:
        forward: Processes input image embeddings and query embeddings through the transformer.

    Examples:
        >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
        >>> image_embedding = torch.randn(1, 256, 64, 64)
        >>> query_embedding = torch.randn(1, 100, 256)
        >>> output = transformer(image_embedding, query_embedding)
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        Initializes a TwoWayTransformer instance.

        This transformer decoder attends to an input image using queries with supplied positional embeddings.
        It is designed for tasks like object detection, image segmentation, and point cloud processing.

        Args:
            depth (int): Number of layers in the transformer.
            embedding_dim (int): Channel dimension for the input embeddings.
            num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
            mlp_dim (int): Channel dimension internal to the MLP block.
            activation (Type[nn.Module]): Activation function to use in the MLP block.
            attention_downsample_rate (int): Downsampling rate for attention computations.

        Attributes:
            depth (int): Number of layers in the transformer.
            embedding_dim (int): Channel dimension for the input embeddings.
            num_heads (int): Number of heads for multihead attention.
            mlp_dim (int): Internal channel dimension for the MLP block.
            layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
            final_attn_token_to_image (Attention): Final attention layer from queries to image.
            norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries.

        Examples:
            >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
            >>> transformer
            TwoWayTransformer(
              (layers): ModuleList(
                (0-4): 5 x TwoWayAttentionBlock(...)
              )
              (final_attn_token_to_image): Attention(...)
              (norm_final_attn): LayerNorm(...)
            )
        """
        super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
        self.layers = nn.ModuleList()
        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )


class RoPEAttention(Attention):
    """Implements rotary position encoding for attention mechanisms in transformer architectures."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
        **kwargs,
    ):
        """Initializes RoPEAttention with rotary position encoding for attention mechanisms."""
        super().__init__(*args, **kwargs)

        self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        self.freqs_cis = freqs_cis
        self.rope_k_repeat = rope_k_repeat

    def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
        """Applies rotary position encoding and computes attention between query, key, and value tensors."""
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Apply rotary position encoding
        w = h = math.sqrt(q.shape[-2])
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k[:, :, :num_k_rope] = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
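
# Illustrative usage sketch (not part of the packaged module): the constructor arguments are
# assumed to be forwarded to the parent SAM Attention class (embedding_dim, num_heads), and the
# token count matches the default feat_sizes (32 * 32 = 1024) so the precomputed frequencies apply.
#   >>> attn = RoPEAttention(embedding_dim=256, num_heads=8)
#   >>> q = k = v = torch.randn(1, 32 * 32, 256)
#   >>> attn(q, k, v).shape
#   torch.Size([1, 1024, 256])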


def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    """Applies pooling and optional normalization to a tensor, handling permutations for spatial operations."""
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    if norm:
        x = norm(x)

    return x
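
# Illustrative usage sketch (not part of the packaged module): do_pool expects channels-last
# input and restores the (B, H, W, C) layout after pooling.
#   >>> feats = torch.randn(1, 56, 56, 96)
#   >>> do_pool(feats, nn.MaxPool2d(kernel_size=2, stride=2)).shape
#   torch.Size([1, 28, 28, 96])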


class MultiScaleAttention(nn.Module):
    """Implements multi-scale self-attention with optional query pooling for efficient feature extraction."""

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        """Initializes a multi-scale attention module with configurable query pooling and linear projections."""
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out

        self.num_heads = num_heads
        head_dim = dim_out // num_heads
        self.scale = head_dim**-0.5

        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies multi-scale attention to input tensor, optionally downsampling query features."""
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x
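
# Illustrative usage sketch (not part of the packaged module): without a q_pool module the
# spatial size is preserved; passing q_pool (e.g. nn.MaxPool2d(2, 2)) halves H and W of the output.
#   >>> msa = MultiScaleAttention(dim=96, dim_out=96, num_heads=3)
#   >>> x = torch.randn(1, 56, 56, 96)
#   >>> msa(x).shape
#   torch.Size([1, 56, 56, 96])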


class MultiScaleBlock(nn.Module):
    """Multiscale attention block with window partitioning and query pooling for efficient vision transformers."""

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        """Initializes a multi-scale attention block with optional window partitioning and downsampling."""
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            act=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies multi-scale attention and MLP processing to input tensor, with optional windowing."""
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
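
# Illustrative usage sketch (not part of the packaged module): with dim == dim_out and no
# q_stride the block is shape-preserving; q_stride=(2, 2) together with a larger dim_out
# halves H and W while widening the channels.
#   >>> blk = MultiScaleBlock(dim=96, dim_out=96, num_heads=3, window_size=8)
#   >>> x = torch.randn(1, 56, 56, 96)
#   >>> blk(x).shape
#   torch.Size([1, 56, 56, 96])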


class PositionEmbeddingSine(nn.Module):
    """Generates sinusoidal positional embeddings for 2D inputs like images."""

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
    ):
        """Initializes sinusoidal position embeddings for 2D image inputs."""
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        self.cache = {}

    def _encode_xy(self, x, y):
        """Encodes 2D positions using sine and cosine functions for positional embeddings."""
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
        pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        """Encodes box coordinates and dimensions into positional embeddings for object detection tasks."""
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        """Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels."""
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        """Generate sinusoidal position embeddings for 2D inputs."""
        cache_key = (x.shape[-2], x.shape[-1])
        if cache_key in self.cache:
            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
        y_embed = (
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
            .view(1, -1, 1)
            .repeat(x.shape[0], 1, x.shape[-1])
        )
        x_embed = (
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
            .view(1, 1, -1)
            .repeat(x.shape[0], x.shape[-2], 1)
        )

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        self.cache[cache_key] = pos[0]
        return pos
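
# Illustrative usage sketch (not part of the packaged module): the embedding depends only on the
# spatial size of x, has shape (B, num_pos_feats, H, W), and is cached per (H, W) pair.
#   >>> pe = PositionEmbeddingSine(num_pos_feats=256)
#   >>> x = torch.randn(2, 64, 32, 32)
#   >>> pe(x).shape
#   torch.Size([2, 256, 32, 32])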