dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
from __future__ import annotations
|
|
2
3
|
|
|
3
4
|
import copy
|
|
4
5
|
import math
|
|
5
6
|
from functools import partial
|
|
6
|
-
from typing import Any, Optional, Tuple, Type, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -17,8 +17,7 @@ from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis,
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class DropPath(nn.Module):
|
|
20
|
-
"""
|
|
21
|
-
Implements stochastic depth regularization for neural networks during training.
|
|
20
|
+
"""Implements stochastic depth regularization for neural networks during training.
|
|
22
21
|
|
|
23
22
|
Attributes:
|
|
24
23
|
drop_prob (float): Probability of dropping a path during training.
|
|
@@ -33,14 +32,14 @@ class DropPath(nn.Module):
|
|
|
33
32
|
>>> output = drop_path(x)
|
|
34
33
|
"""
|
|
35
34
|
|
|
36
|
-
def __init__(self, drop_prob=0.0, scale_by_keep=True):
|
|
35
|
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
|
37
36
|
"""Initialize DropPath module for stochastic depth regularization during training."""
|
|
38
37
|
super().__init__()
|
|
39
38
|
self.drop_prob = drop_prob
|
|
40
39
|
self.scale_by_keep = scale_by_keep
|
|
41
40
|
|
|
42
|
-
def forward(self, x):
|
|
43
|
-
"""
|
|
41
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
42
|
+
"""Apply stochastic depth to input tensor during training, with optional scaling."""
|
|
44
43
|
if self.drop_prob == 0.0 or not self.training:
|
|
45
44
|
return x
|
|
46
45
|
keep_prob = 1 - self.drop_prob
|
|
@@ -52,16 +51,14 @@ class DropPath(nn.Module):
|
|
|
52
51
|
|
|
53
52
|
|
|
54
53
|
class MaskDownSampler(nn.Module):
|
|
55
|
-
"""
|
|
56
|
-
A mask downsampling and embedding module for efficient processing of input masks.
|
|
54
|
+
"""A mask downsampling and embedding module for efficient processing of input masks.
|
|
57
55
|
|
|
58
|
-
This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks
|
|
59
|
-
|
|
60
|
-
functions.
|
|
56
|
+
This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks while
|
|
57
|
+
expanding their channel dimensions using convolutional layers, layer normalization, and activation functions.
|
|
61
58
|
|
|
62
59
|
Attributes:
|
|
63
|
-
encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
|
|
64
|
-
|
|
60
|
+
encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and activation
|
|
61
|
+
functions for downsampling and embedding masks.
|
|
65
62
|
|
|
66
63
|
Methods:
|
|
67
64
|
forward: Downsamples and encodes input mask to embed_dim channels.
|
|
@@ -76,14 +73,14 @@ class MaskDownSampler(nn.Module):
|
|
|
76
73
|
|
|
77
74
|
def __init__(
|
|
78
75
|
self,
|
|
79
|
-
embed_dim=256,
|
|
80
|
-
kernel_size=4,
|
|
81
|
-
stride=4,
|
|
82
|
-
padding=0,
|
|
83
|
-
total_stride=16,
|
|
84
|
-
activation=nn.GELU,
|
|
76
|
+
embed_dim: int = 256,
|
|
77
|
+
kernel_size: int = 4,
|
|
78
|
+
stride: int = 4,
|
|
79
|
+
padding: int = 0,
|
|
80
|
+
total_stride: int = 16,
|
|
81
|
+
activation: type[nn.Module] = nn.GELU,
|
|
85
82
|
):
|
|
86
|
-
"""
|
|
83
|
+
"""Initialize a mask downsampler module for progressive downsampling and channel expansion."""
|
|
87
84
|
super().__init__()
|
|
88
85
|
num_layers = int(math.log2(total_stride) // math.log2(stride))
|
|
89
86
|
assert stride**num_layers == total_stride
|
|
@@ -106,17 +103,16 @@ class MaskDownSampler(nn.Module):
|
|
|
106
103
|
|
|
107
104
|
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
|
108
105
|
|
|
109
|
-
def forward(self, x):
|
|
110
|
-
"""
|
|
106
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
107
|
+
"""Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
|
|
111
108
|
return self.encoder(x)
|
|
112
109
|
|
|
113
110
|
|
|
114
111
|
class CXBlock(nn.Module):
|
|
115
|
-
"""
|
|
116
|
-
ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
|
112
|
+
"""ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
|
117
113
|
|
|
118
|
-
This block implements a modified version of the ConvNeXt architecture, offering improved performance and
|
|
119
|
-
|
|
114
|
+
This block implements a modified version of the ConvNeXt architecture, offering improved performance and flexibility
|
|
115
|
+
in feature extraction.
|
|
120
116
|
|
|
121
117
|
Attributes:
|
|
122
118
|
dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
|
|
@@ -141,15 +137,14 @@ class CXBlock(nn.Module):
|
|
|
141
137
|
|
|
142
138
|
def __init__(
|
|
143
139
|
self,
|
|
144
|
-
dim,
|
|
145
|
-
kernel_size=7,
|
|
146
|
-
padding=3,
|
|
147
|
-
drop_path=0.0,
|
|
148
|
-
layer_scale_init_value=1e-6,
|
|
149
|
-
use_dwconv=True,
|
|
140
|
+
dim: int,
|
|
141
|
+
kernel_size: int = 7,
|
|
142
|
+
padding: int = 3,
|
|
143
|
+
drop_path: float = 0.0,
|
|
144
|
+
layer_scale_init_value: float = 1e-6,
|
|
145
|
+
use_dwconv: bool = True,
|
|
150
146
|
):
|
|
151
|
-
"""
|
|
152
|
-
Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
|
147
|
+
"""Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
|
153
148
|
|
|
154
149
|
This block implements a modified version of the ConvNeXt architecture, offering improved performance and
|
|
155
150
|
flexibility in feature extraction.
|
|
@@ -188,8 +183,8 @@ class CXBlock(nn.Module):
|
|
|
188
183
|
)
|
|
189
184
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
|
190
185
|
|
|
191
|
-
def forward(self, x):
|
|
192
|
-
"""
|
|
186
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
187
|
+
"""Apply ConvNeXt block operations to input tensor, including convolutions and residual connection."""
|
|
193
188
|
input = x
|
|
194
189
|
x = self.dwconv(x)
|
|
195
190
|
x = self.norm(x)
|
|
@@ -206,8 +201,7 @@ class CXBlock(nn.Module):
|
|
|
206
201
|
|
|
207
202
|
|
|
208
203
|
class Fuser(nn.Module):
|
|
209
|
-
"""
|
|
210
|
-
A module for fusing features through multiple layers of a neural network.
|
|
204
|
+
"""A module for fusing features through multiple layers of a neural network.
|
|
211
205
|
|
|
212
206
|
This class applies a series of identical layers to an input tensor, optionally projecting the input first.
|
|
213
207
|
|
|
@@ -227,9 +221,8 @@ class Fuser(nn.Module):
|
|
|
227
221
|
torch.Size([1, 256, 32, 32])
|
|
228
222
|
"""
|
|
229
223
|
|
|
230
|
-
def __init__(self, layer, num_layers, dim=None, input_projection=False):
|
|
231
|
-
"""
|
|
232
|
-
Initializes the Fuser module for feature fusion through multiple layers.
|
|
224
|
+
def __init__(self, layer: nn.Module, num_layers: int, dim: int | None = None, input_projection: bool = False):
|
|
225
|
+
"""Initialize the Fuser module for feature fusion through multiple layers.
|
|
233
226
|
|
|
234
227
|
This module creates a sequence of identical layers and optionally applies an input projection.
|
|
235
228
|
|
|
@@ -253,8 +246,8 @@ class Fuser(nn.Module):
|
|
|
253
246
|
assert dim is not None
|
|
254
247
|
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
|
|
255
248
|
|
|
256
|
-
def forward(self, x):
|
|
257
|
-
"""
|
|
249
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
250
|
+
"""Apply a series of layers to the input tensor, optionally projecting it first."""
|
|
258
251
|
x = self.proj(x)
|
|
259
252
|
for layer in self.layers:
|
|
260
253
|
x = layer(x)
|
|
@@ -262,12 +255,11 @@ class Fuser(nn.Module):
|
|
|
262
255
|
|
|
263
256
|
|
|
264
257
|
class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
|
|
265
|
-
"""
|
|
266
|
-
A two-way attention block for performing self-attention and cross-attention in both directions.
|
|
258
|
+
"""A two-way attention block for performing self-attention and cross-attention in both directions.
|
|
267
259
|
|
|
268
|
-
This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on
|
|
269
|
-
|
|
270
|
-
|
|
260
|
+
This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse inputs,
|
|
261
|
+
cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from dense to sparse
|
|
262
|
+
inputs.
|
|
271
263
|
|
|
272
264
|
Attributes:
|
|
273
265
|
self_attn (Attention): Self-attention layer for queries.
|
|
@@ -295,16 +287,15 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
|
|
|
295
287
|
embedding_dim: int,
|
|
296
288
|
num_heads: int,
|
|
297
289
|
mlp_dim: int = 2048,
|
|
298
|
-
activation:
|
|
290
|
+
activation: type[nn.Module] = nn.ReLU,
|
|
299
291
|
attention_downsample_rate: int = 2,
|
|
300
292
|
skip_first_layer_pe: bool = False,
|
|
301
293
|
) -> None:
|
|
302
|
-
"""
|
|
303
|
-
Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
|
|
294
|
+
"""Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
|
|
304
295
|
|
|
305
296
|
This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
|
|
306
|
-
inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention
|
|
307
|
-
|
|
297
|
+
inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from
|
|
298
|
+
dense to sparse inputs.
|
|
308
299
|
|
|
309
300
|
Args:
|
|
310
301
|
embedding_dim (int): The channel dimension of the embeddings.
|
|
@@ -325,12 +316,11 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
|
|
|
325
316
|
|
|
326
317
|
|
|
327
318
|
class SAM2TwoWayTransformer(TwoWayTransformer):
|
|
328
|
-
"""
|
|
329
|
-
A Two-Way Transformer module for simultaneous attention to image and query points.
|
|
319
|
+
"""A Two-Way Transformer module for simultaneous attention to image and query points.
|
|
330
320
|
|
|
331
|
-
This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an
|
|
332
|
-
|
|
333
|
-
|
|
321
|
+
This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an input
|
|
322
|
+
image using queries with supplied positional embeddings. It is particularly useful for tasks like object detection,
|
|
323
|
+
image segmentation, and point cloud processing.
|
|
334
324
|
|
|
335
325
|
Attributes:
|
|
336
326
|
depth (int): Number of layers in the transformer.
|
|
@@ -359,14 +349,13 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
|
|
|
359
349
|
embedding_dim: int,
|
|
360
350
|
num_heads: int,
|
|
361
351
|
mlp_dim: int,
|
|
362
|
-
activation:
|
|
352
|
+
activation: type[nn.Module] = nn.ReLU,
|
|
363
353
|
attention_downsample_rate: int = 2,
|
|
364
354
|
) -> None:
|
|
365
|
-
"""
|
|
366
|
-
Initializes a SAM2TwoWayTransformer instance.
|
|
355
|
+
"""Initialize a SAM2TwoWayTransformer instance.
|
|
367
356
|
|
|
368
|
-
This transformer decoder attends to an input image using queries with supplied positional embeddings.
|
|
369
|
-
|
|
357
|
+
This transformer decoder attends to an input image using queries with supplied positional embeddings. It is
|
|
358
|
+
designed for tasks like object detection, image segmentation, and point cloud processing.
|
|
370
359
|
|
|
371
360
|
Args:
|
|
372
361
|
depth (int): Number of layers in the transformer.
|
|
@@ -403,15 +392,14 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
|
|
|
403
392
|
|
|
404
393
|
|
|
405
394
|
class RoPEAttention(Attention):
|
|
406
|
-
"""
|
|
407
|
-
Implements rotary position encoding for attention mechanisms in transformer architectures.
|
|
395
|
+
"""Implements rotary position encoding for attention mechanisms in transformer architectures.
|
|
408
396
|
|
|
409
|
-
This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance
|
|
410
|
-
|
|
397
|
+
This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance the
|
|
398
|
+
positional awareness of the attention mechanism.
|
|
411
399
|
|
|
412
400
|
Attributes:
|
|
413
401
|
compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
|
|
414
|
-
freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
|
|
402
|
+
freqs_cis (torch.Tensor): Precomputed frequency tensor for rotary encoding.
|
|
415
403
|
rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
|
|
416
404
|
|
|
417
405
|
Methods:
|
|
@@ -430,12 +418,12 @@ class RoPEAttention(Attention):
|
|
|
430
418
|
def __init__(
|
|
431
419
|
self,
|
|
432
420
|
*args,
|
|
433
|
-
rope_theta=10000.0,
|
|
434
|
-
rope_k_repeat=False,
|
|
435
|
-
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
|
|
421
|
+
rope_theta: float = 10000.0,
|
|
422
|
+
rope_k_repeat: bool = False,
|
|
423
|
+
feat_sizes: tuple[int, int] = (32, 32), # [w, h] for stride 16 feats at 512 resolution
|
|
436
424
|
**kwargs,
|
|
437
425
|
):
|
|
438
|
-
"""
|
|
426
|
+
"""Initialize RoPEAttention with rotary position encoding for enhanced positional awareness."""
|
|
439
427
|
super().__init__(*args, **kwargs)
|
|
440
428
|
|
|
441
429
|
self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
|
|
@@ -443,8 +431,8 @@ class RoPEAttention(Attention):
|
|
|
443
431
|
self.freqs_cis = freqs_cis
|
|
444
432
|
self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories
|
|
445
433
|
|
|
446
|
-
def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
|
|
447
|
-
"""
|
|
434
|
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_k_exclude_rope: int = 0) -> torch.Tensor:
|
|
435
|
+
"""Apply rotary position encoding and compute attention between query, key, and value tensors."""
|
|
448
436
|
q = self.q_proj(q)
|
|
449
437
|
k = self.k_proj(k)
|
|
450
438
|
v = self.v_proj(v)
|
|
@@ -486,7 +474,7 @@ class RoPEAttention(Attention):
|
|
|
486
474
|
|
|
487
475
|
|
|
488
476
|
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
|
489
|
-
"""
|
|
477
|
+
"""Apply pooling and optional normalization to a tensor, handling spatial dimension permutations."""
|
|
490
478
|
if pool is None:
|
|
491
479
|
return x
|
|
492
480
|
# (B, H, W, C) -> (B, C, H, W)
|
|
@@ -501,12 +489,11 @@ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.T
|
|
|
501
489
|
|
|
502
490
|
|
|
503
491
|
class MultiScaleAttention(nn.Module):
|
|
504
|
-
"""
|
|
505
|
-
Implements multiscale self-attention with optional query pooling for efficient feature extraction.
|
|
492
|
+
"""Implements multiscale self-attention with optional query pooling for efficient feature extraction.
|
|
506
493
|
|
|
507
|
-
This class provides a flexible implementation of multiscale attention, allowing for optional
|
|
508
|
-
|
|
509
|
-
|
|
494
|
+
This class provides a flexible implementation of multiscale attention, allowing for optional downsampling of query
|
|
495
|
+
features through pooling. It's designed to enhance the model's ability to capture multiscale information in visual
|
|
496
|
+
tasks.
|
|
510
497
|
|
|
511
498
|
Attributes:
|
|
512
499
|
dim (int): Input dimension of the feature map.
|
|
@@ -537,7 +524,7 @@ class MultiScaleAttention(nn.Module):
|
|
|
537
524
|
num_heads: int,
|
|
538
525
|
q_pool: nn.Module = None,
|
|
539
526
|
):
|
|
540
|
-
"""
|
|
527
|
+
"""Initialize multiscale attention with optional query pooling for efficient feature extraction."""
|
|
541
528
|
super().__init__()
|
|
542
529
|
|
|
543
530
|
self.dim = dim
|
|
@@ -552,7 +539,7 @@ class MultiScaleAttention(nn.Module):
|
|
|
552
539
|
self.proj = nn.Linear(dim_out, dim_out)
|
|
553
540
|
|
|
554
541
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
555
|
-
"""
|
|
542
|
+
"""Apply multiscale attention with optional query pooling to extract multiscale features."""
|
|
556
543
|
B, H, W, _ = x.shape
|
|
557
544
|
# qkv with shape (B, H * W, 3, nHead, C)
|
|
558
545
|
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
|
|
@@ -581,11 +568,10 @@ class MultiScaleAttention(nn.Module):
|
|
|
581
568
|
|
|
582
569
|
|
|
583
570
|
class MultiScaleBlock(nn.Module):
|
|
584
|
-
"""
|
|
585
|
-
A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
|
|
571
|
+
"""A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
|
|
586
572
|
|
|
587
|
-
This class implements a multiscale attention mechanism with optional window partitioning and downsampling,
|
|
588
|
-
|
|
573
|
+
This class implements a multiscale attention mechanism with optional window partitioning and downsampling, designed
|
|
574
|
+
for use in vision transformer architectures.
|
|
589
575
|
|
|
590
576
|
Attributes:
|
|
591
577
|
dim (int): Input dimension of the block.
|
|
@@ -593,7 +579,7 @@ class MultiScaleBlock(nn.Module):
|
|
|
593
579
|
norm1 (nn.Module): First normalization layer.
|
|
594
580
|
window_size (int): Size of the window for partitioning.
|
|
595
581
|
pool (nn.Module | None): Pooling layer for query downsampling.
|
|
596
|
-
q_stride (
|
|
582
|
+
q_stride (tuple[int, int] | None): Stride for query pooling.
|
|
597
583
|
attn (MultiScaleAttention): Multi-scale attention module.
|
|
598
584
|
drop_path (nn.Module): Drop path layer for regularization.
|
|
599
585
|
norm2 (nn.Module): Second normalization layer.
|
|
@@ -618,12 +604,12 @@ class MultiScaleBlock(nn.Module):
|
|
|
618
604
|
num_heads: int,
|
|
619
605
|
mlp_ratio: float = 4.0,
|
|
620
606
|
drop_path: float = 0.0,
|
|
621
|
-
norm_layer:
|
|
622
|
-
q_stride:
|
|
623
|
-
act_layer: nn.Module = nn.GELU,
|
|
607
|
+
norm_layer: nn.Module | str = "LayerNorm",
|
|
608
|
+
q_stride: tuple[int, int] | None = None,
|
|
609
|
+
act_layer: type[nn.Module] = nn.GELU,
|
|
624
610
|
window_size: int = 0,
|
|
625
611
|
):
|
|
626
|
-
"""
|
|
612
|
+
"""Initialize a multiscale attention block with window partitioning and optional query pooling."""
|
|
627
613
|
super().__init__()
|
|
628
614
|
|
|
629
615
|
if isinstance(norm_layer, str):
|
|
@@ -660,7 +646,7 @@ class MultiScaleBlock(nn.Module):
|
|
|
660
646
|
self.proj = nn.Linear(dim, dim_out)
|
|
661
647
|
|
|
662
648
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
663
|
-
"""
|
|
649
|
+
"""Process input through multiscale attention and MLP, with optional windowing and downsampling."""
|
|
664
650
|
shortcut = x # B, H, W, C
|
|
665
651
|
x = self.norm1(x)
|
|
666
652
|
|
|
@@ -696,11 +682,10 @@ class MultiScaleBlock(nn.Module):
|
|
|
696
682
|
|
|
697
683
|
|
|
698
684
|
class PositionEmbeddingSine(nn.Module):
|
|
699
|
-
"""
|
|
700
|
-
A module for generating sinusoidal positional embeddings for 2D inputs like images.
|
|
685
|
+
"""A module for generating sinusoidal positional embeddings for 2D inputs like images.
|
|
701
686
|
|
|
702
|
-
This class implements sinusoidal position encoding for 2D spatial positions, which can be used in
|
|
703
|
-
|
|
687
|
+
This class implements sinusoidal position encoding for 2D spatial positions, which can be used in transformer-based
|
|
688
|
+
models for computer vision tasks.
|
|
704
689
|
|
|
705
690
|
Attributes:
|
|
706
691
|
num_pos_feats (int): Number of positional features (half of the embedding dimension).
|
|
@@ -725,12 +710,12 @@ class PositionEmbeddingSine(nn.Module):
|
|
|
725
710
|
|
|
726
711
|
def __init__(
|
|
727
712
|
self,
|
|
728
|
-
num_pos_feats,
|
|
713
|
+
num_pos_feats: int,
|
|
729
714
|
temperature: int = 10000,
|
|
730
715
|
normalize: bool = True,
|
|
731
|
-
scale:
|
|
716
|
+
scale: float | None = None,
|
|
732
717
|
):
|
|
733
|
-
"""
|
|
718
|
+
"""Initialize sinusoidal position embeddings for 2D image inputs."""
|
|
734
719
|
super().__init__()
|
|
735
720
|
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
|
736
721
|
self.num_pos_feats = num_pos_feats // 2
|
|
@@ -744,8 +729,8 @@ class PositionEmbeddingSine(nn.Module):
|
|
|
744
729
|
|
|
745
730
|
self.cache = {}
|
|
746
731
|
|
|
747
|
-
def _encode_xy(self, x, y):
|
|
748
|
-
"""
|
|
732
|
+
def _encode_xy(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
733
|
+
"""Encode 2D positions using sine/cosine functions for transformer positional embeddings."""
|
|
749
734
|
assert len(x) == len(y) and x.ndim == y.ndim == 1
|
|
750
735
|
x_embed = x * self.scale
|
|
751
736
|
y_embed = y * self.scale
|
|
@@ -760,16 +745,16 @@ class PositionEmbeddingSine(nn.Module):
|
|
|
760
745
|
return pos_x, pos_y
|
|
761
746
|
|
|
762
747
|
@torch.no_grad()
|
|
763
|
-
def encode_boxes(self, x, y, w, h):
|
|
764
|
-
"""
|
|
748
|
+
def encode_boxes(self, x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
|
|
749
|
+
"""Encode box coordinates and dimensions into positional embeddings for detection."""
|
|
765
750
|
pos_x, pos_y = self._encode_xy(x, y)
|
|
766
751
|
return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
|
|
767
752
|
|
|
768
753
|
encode = encode_boxes # Backwards compatibility
|
|
769
754
|
|
|
770
755
|
@torch.no_grad()
|
|
771
|
-
def encode_points(self, x, y, labels):
|
|
772
|
-
"""
|
|
756
|
+
def encode_points(self, x: torch.Tensor, y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
|
757
|
+
"""Encode 2D points with sinusoidal embeddings and append labels."""
|
|
773
758
|
(bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
|
|
774
759
|
assert bx == by and nx == ny and bx == bl and nx == nl
|
|
775
760
|
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
|
|
@@ -777,8 +762,8 @@ class PositionEmbeddingSine(nn.Module):
|
|
|
777
762
|
return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
|
|
778
763
|
|
|
779
764
|
@torch.no_grad()
|
|
780
|
-
def forward(self, x: torch.Tensor):
|
|
781
|
-
"""
|
|
765
|
+
def forward(self, x: torch.Tensor) -> Tensor:
|
|
766
|
+
"""Generate sinusoidal position embeddings for 2D inputs like images."""
|
|
782
767
|
cache_key = (x.shape[-2], x.shape[-1])
|
|
783
768
|
if cache_key in self.cache:
|
|
784
769
|
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
|
|
@@ -811,8 +796,7 @@ class PositionEmbeddingSine(nn.Module):
|
|
|
811
796
|
|
|
812
797
|
|
|
813
798
|
class PositionEmbeddingRandom(nn.Module):
|
|
814
|
-
"""
|
|
815
|
-
Positional encoding using random spatial frequencies.
|
|
799
|
+
"""Positional encoding using random spatial frequencies.
|
|
816
800
|
|
|
817
801
|
This class generates positional embeddings for input coordinates using random spatial frequencies. It is
|
|
818
802
|
particularly useful for transformer-based models that require position information.
|
|
@@ -833,8 +817,8 @@ class PositionEmbeddingRandom(nn.Module):
|
|
|
833
817
|
torch.Size([128, 32, 32])
|
|
834
818
|
"""
|
|
835
819
|
|
|
836
|
-
def __init__(self, num_pos_feats: int = 64, scale:
|
|
837
|
-
"""
|
|
820
|
+
def __init__(self, num_pos_feats: int = 64, scale: float | None = None) -> None:
|
|
821
|
+
"""Initialize random spatial frequency position embedding for transformers."""
|
|
838
822
|
super().__init__()
|
|
839
823
|
if scale is None or scale <= 0.0:
|
|
840
824
|
scale = 1.0
|
|
@@ -845,7 +829,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|
|
845
829
|
torch.backends.cudnn.deterministic = False
|
|
846
830
|
|
|
847
831
|
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
|
848
|
-
"""
|
|
832
|
+
"""Encode normalized [0,1] coordinates using random spatial frequencies."""
|
|
849
833
|
# Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
|
850
834
|
coords = 2 * coords - 1
|
|
851
835
|
coords = coords @ self.positional_encoding_gaussian_matrix
|
|
@@ -853,11 +837,14 @@ class PositionEmbeddingRandom(nn.Module):
|
|
|
853
837
|
# Outputs d_1 x ... x d_n x C shape
|
|
854
838
|
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
|
855
839
|
|
|
856
|
-
def forward(self, size:
|
|
857
|
-
"""
|
|
840
|
+
def forward(self, size: tuple[int, int]) -> torch.Tensor:
|
|
841
|
+
"""Generate positional encoding for a grid using random spatial frequencies."""
|
|
858
842
|
h, w = size
|
|
859
|
-
|
|
860
|
-
|
|
843
|
+
grid = torch.ones(
|
|
844
|
+
(h, w),
|
|
845
|
+
device=self.positional_encoding_gaussian_matrix.device,
|
|
846
|
+
dtype=self.positional_encoding_gaussian_matrix.dtype,
|
|
847
|
+
)
|
|
861
848
|
y_embed = grid.cumsum(dim=0) - 0.5
|
|
862
849
|
x_embed = grid.cumsum(dim=1) - 0.5
|
|
863
850
|
y_embed = y_embed / h
|
|
@@ -866,21 +853,20 @@ class PositionEmbeddingRandom(nn.Module):
|
|
|
866
853
|
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
|
867
854
|
return pe.permute(2, 0, 1) # C x H x W
|
|
868
855
|
|
|
869
|
-
def forward_with_coords(self, coords_input: torch.Tensor, image_size:
|
|
870
|
-
"""Positionally
|
|
856
|
+
def forward_with_coords(self, coords_input: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
|
857
|
+
"""Positionally encode input coordinates, normalizing them to [0,1] based on the given image size."""
|
|
871
858
|
coords = coords_input.clone()
|
|
872
859
|
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
|
873
860
|
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
|
874
|
-
return self._pe_encoding(coords
|
|
861
|
+
return self._pe_encoding(coords) # B x N x C
|
|
875
862
|
|
|
876
863
|
|
|
877
864
|
class Block(nn.Module):
|
|
878
|
-
"""
|
|
879
|
-
Transformer block with support for window attention and residual propagation.
|
|
865
|
+
"""Transformer block with support for window attention and residual propagation.
|
|
880
866
|
|
|
881
|
-
This class implements a transformer block that can use either global or windowed self-attention,
|
|
882
|
-
|
|
883
|
-
|
|
867
|
+
This class implements a transformer block that can use either global or windowed self-attention, followed by a
|
|
868
|
+
feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
|
|
869
|
+
architectures.
|
|
884
870
|
|
|
885
871
|
Attributes:
|
|
886
872
|
norm1 (nn.Module): First normalization layer.
|
|
@@ -907,19 +893,18 @@ class Block(nn.Module):
|
|
|
907
893
|
num_heads: int,
|
|
908
894
|
mlp_ratio: float = 4.0,
|
|
909
895
|
qkv_bias: bool = True,
|
|
910
|
-
norm_layer:
|
|
911
|
-
act_layer:
|
|
896
|
+
norm_layer: type[nn.Module] = nn.LayerNorm,
|
|
897
|
+
act_layer: type[nn.Module] = nn.GELU,
|
|
912
898
|
use_rel_pos: bool = False,
|
|
913
899
|
rel_pos_zero_init: bool = True,
|
|
914
900
|
window_size: int = 0,
|
|
915
|
-
input_size:
|
|
901
|
+
input_size: tuple[int, int] | None = None,
|
|
916
902
|
) -> None:
|
|
917
|
-
"""
|
|
918
|
-
Initializes a transformer block with optional window attention and relative positional embeddings.
|
|
903
|
+
"""Initialize a transformer block with optional window attention and relative positional embeddings.
|
|
919
904
|
|
|
920
|
-
This constructor sets up a transformer block that can use either global or windowed self-attention,
|
|
921
|
-
|
|
922
|
-
|
|
905
|
+
This constructor sets up a transformer block that can use either global or windowed self-attention, followed by
|
|
906
|
+
a feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
|
|
907
|
+
architectures.
|
|
923
908
|
|
|
924
909
|
Args:
|
|
925
910
|
dim (int): Number of input channels.
|
|
@@ -931,7 +916,7 @@ class Block(nn.Module):
|
|
|
931
916
|
use_rel_pos (bool): If True, uses relative positional embeddings in attention.
|
|
932
917
|
rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
|
|
933
918
|
window_size (int): Size of attention window. If 0, uses global attention.
|
|
934
|
-
input_size (
|
|
919
|
+
input_size (tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
|
|
935
920
|
|
|
936
921
|
Examples:
|
|
937
922
|
>>> block = Block(dim=256, num_heads=8, window_size=7)
|
|
@@ -957,7 +942,7 @@ class Block(nn.Module):
|
|
|
957
942
|
self.window_size = window_size
|
|
958
943
|
|
|
959
944
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
960
|
-
"""
|
|
945
|
+
"""Process input through transformer block with optional windowed self-attention and residual connection."""
|
|
961
946
|
shortcut = x
|
|
962
947
|
x = self.norm1(x)
|
|
963
948
|
# Window partition
|
|
@@ -975,35 +960,30 @@ class Block(nn.Module):
|
|
|
975
960
|
|
|
976
961
|
|
|
977
962
|
class REAttention(nn.Module):
|
|
978
|
-
"""
|
|
979
|
-
Rotary Embedding Attention module for efficient self-attention in transformer architectures.
|
|
963
|
+
"""Relative Position Attention module for efficient self-attention in transformer architectures.
|
|
980
964
|
|
|
981
|
-
This class implements a multi-head attention mechanism with
|
|
982
|
-
|
|
983
|
-
|
|
965
|
+
This class implements a multi-head attention mechanism with relative positional embeddings, designed for use in
|
|
966
|
+
vision transformer models. It supports optional query pooling and window partitioning for efficient processing of
|
|
967
|
+
large inputs.
|
|
984
968
|
|
|
985
969
|
Attributes:
|
|
986
|
-
compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
|
|
987
|
-
freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
|
|
988
|
-
rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
|
|
989
|
-
q_proj (nn.Linear): Linear projection for query.
|
|
990
|
-
k_proj (nn.Linear): Linear projection for key.
|
|
991
|
-
v_proj (nn.Linear): Linear projection for value.
|
|
992
|
-
out_proj (nn.Linear): Output projection.
|
|
993
970
|
num_heads (int): Number of attention heads.
|
|
994
|
-
|
|
971
|
+
scale (float): Scaling factor for attention computation.
|
|
972
|
+
qkv (nn.Linear): Linear projection for query, key, and value.
|
|
973
|
+
proj (nn.Linear): Output projection layer.
|
|
974
|
+
use_rel_pos (bool): Whether to use relative positional embeddings.
|
|
975
|
+
rel_pos_h (nn.Parameter): Relative positional embeddings for height dimension.
|
|
976
|
+
rel_pos_w (nn.Parameter): Relative positional embeddings for width dimension.
|
|
995
977
|
|
|
996
978
|
Methods:
|
|
997
|
-
forward: Applies
|
|
979
|
+
forward: Applies multi-head attention with optional relative positional encoding to input tensor.
|
|
998
980
|
|
|
999
981
|
Examples:
|
|
1000
|
-
>>>
|
|
1001
|
-
>>>
|
|
1002
|
-
>>>
|
|
1003
|
-
>>> v = torch.randn(1, 1024, 256)
|
|
1004
|
-
>>> output = rope_attn(q, k, v)
|
|
982
|
+
>>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
|
|
983
|
+
>>> x = torch.randn(1, 32, 32, 256)
|
|
984
|
+
>>> output = attention(x)
|
|
1005
985
|
>>> print(output.shape)
|
|
1006
|
-
torch.Size([1,
|
|
986
|
+
torch.Size([1, 32, 32, 256])
|
|
1007
987
|
"""
|
|
1008
988
|
|
|
1009
989
|
def __init__(
|
|
@@ -1013,22 +993,21 @@ class REAttention(nn.Module):
|
|
|
1013
993
|
qkv_bias: bool = True,
|
|
1014
994
|
use_rel_pos: bool = False,
|
|
1015
995
|
rel_pos_zero_init: bool = True,
|
|
1016
|
-
input_size:
|
|
996
|
+
input_size: tuple[int, int] | None = None,
|
|
1017
997
|
) -> None:
|
|
1018
|
-
"""
|
|
1019
|
-
Initializes a Relative Position Attention module for transformer-based architectures.
|
|
998
|
+
"""Initialize a Relative Position Attention module for transformer-based architectures.
|
|
1020
999
|
|
|
1021
|
-
This module implements multi-head attention with optional relative positional encodings, designed
|
|
1022
|
-
|
|
1000
|
+
This module implements multi-head attention with optional relative positional encodings, designed specifically
|
|
1001
|
+
for vision tasks in transformer models.
|
|
1023
1002
|
|
|
1024
1003
|
Args:
|
|
1025
1004
|
dim (int): Number of input channels.
|
|
1026
|
-
num_heads (int): Number of attention heads.
|
|
1027
|
-
qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
|
|
1028
|
-
use_rel_pos (bool): If True, uses relative positional encodings.
|
|
1029
|
-
rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
|
|
1030
|
-
input_size (
|
|
1031
|
-
Required if use_rel_pos is True.
|
|
1005
|
+
num_heads (int): Number of attention heads.
|
|
1006
|
+
qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
|
|
1007
|
+
use_rel_pos (bool): If True, uses relative positional encodings.
|
|
1008
|
+
rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
|
|
1009
|
+
input_size (tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
|
|
1010
|
+
Required if use_rel_pos is True.
|
|
1032
1011
|
|
|
1033
1012
|
Examples:
|
|
1034
1013
|
>>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
|
|
@@ -1053,7 +1032,7 @@ class REAttention(nn.Module):
|
|
|
1053
1032
|
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
|
1054
1033
|
|
|
1055
1034
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
1056
|
-
"""
|
|
1035
|
+
"""Apply multi-head attention with optional relative positional encoding to input tensor."""
|
|
1057
1036
|
B, H, W, _ = x.shape
|
|
1058
1037
|
# qkv with shape (3, B, nHead, H * W, C)
|
|
1059
1038
|
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
|
@@ -1071,12 +1050,11 @@ class REAttention(nn.Module):
|
|
|
1071
1050
|
|
|
1072
1051
|
|
|
1073
1052
|
class PatchEmbed(nn.Module):
|
|
1074
|
-
"""
|
|
1075
|
-
Image to Patch Embedding module for vision transformer architectures.
|
|
1053
|
+
"""Image to Patch Embedding module for vision transformer architectures.
|
|
1076
1054
|
|
|
1077
|
-
This module converts an input image into a sequence of patch embeddings using a convolutional layer.
|
|
1078
|
-
|
|
1079
|
-
|
|
1055
|
+
This module converts an input image into a sequence of patch embeddings using a convolutional layer. It is commonly
|
|
1056
|
+
used as the first layer in vision transformer architectures to transform image data into a suitable format for
|
|
1057
|
+
subsequent transformer blocks.
|
|
1080
1058
|
|
|
1081
1059
|
Attributes:
|
|
1082
1060
|
proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.
|
|
@@ -1094,22 +1072,21 @@ class PatchEmbed(nn.Module):
|
|
|
1094
1072
|
|
|
1095
1073
|
def __init__(
|
|
1096
1074
|
self,
|
|
1097
|
-
kernel_size:
|
|
1098
|
-
stride:
|
|
1099
|
-
padding:
|
|
1075
|
+
kernel_size: tuple[int, int] = (16, 16),
|
|
1076
|
+
stride: tuple[int, int] = (16, 16),
|
|
1077
|
+
padding: tuple[int, int] = (0, 0),
|
|
1100
1078
|
in_chans: int = 3,
|
|
1101
1079
|
embed_dim: int = 768,
|
|
1102
1080
|
) -> None:
|
|
1103
|
-
"""
|
|
1104
|
-
Initializes the PatchEmbed module for converting image patches to embeddings.
|
|
1081
|
+
"""Initialize the PatchEmbed module for converting image patches to embeddings.
|
|
1105
1082
|
|
|
1106
|
-
This module is typically used as the first layer in vision transformer architectures to transform
|
|
1107
|
-
|
|
1083
|
+
This module is typically used as the first layer in vision transformer architectures to transform image data
|
|
1084
|
+
into a suitable format for subsequent transformer blocks.
|
|
1108
1085
|
|
|
1109
1086
|
Args:
|
|
1110
|
-
kernel_size (
|
|
1111
|
-
stride (
|
|
1112
|
-
padding (
|
|
1087
|
+
kernel_size (tuple[int, int]): Size of the convolutional kernel for patch extraction.
|
|
1088
|
+
stride (tuple[int, int]): Stride of the convolutional operation.
|
|
1089
|
+
padding (tuple[int, int]): Padding applied to the input before convolution.
|
|
1113
1090
|
in_chans (int): Number of input image channels.
|
|
1114
1091
|
embed_dim (int): Dimensionality of the output patch embeddings.
|
|
1115
1092
|
|
|
@@ -1125,5 +1102,5 @@ class PatchEmbed(nn.Module):
|
|
|
1125
1102
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
|
1126
1103
|
|
|
1127
1104
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
1128
|
-
"""
|
|
1105
|
+
"""Compute patch embedding by applying convolution and transposing resulting tensor."""
|
|
1129
1106
|
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|