ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +527 -67
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +44 -37
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +84 -56
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.28.dist-info/METADATA +0 -373
- ultralytics-8.1.28.dist-info/RECORD +0 -197
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/transformer.py
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import math
 from typing import Tuple, Type
@@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock
 
 class TwoWayTransformer(nn.Module):
     """
-    A Two-Way Transformer module
-
-
-
+    A Two-Way Transformer module for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer decoder that attends to an input image using queries with
+    supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
+    cloud processing.
 
     Attributes:
-        depth (int):
-        embedding_dim (int):
-        num_heads (int):
-        mlp_dim (int):
-        layers (nn.ModuleList):
-        final_attn_token_to_image (Attention):
-        norm_final_attn (nn.LayerNorm):
+        depth (int): Number of layers in the transformer.
+        embedding_dim (int): Channel dimension for input embeddings.
+        num_heads (int): Number of heads for multihead attention.
+        mlp_dim (int): Internal channel dimension for the MLP block.
+        layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
+        final_attn_token_to_image (Attention): Final attention layer from queries to image.
+        norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+    Methods:
+        forward: Processes image and point embeddings through the transformer.
+
+    Examples:
+        >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+        >>> image_embedding = torch.randn(1, 256, 32, 32)
+        >>> image_pe = torch.randn(1, 256, 32, 32)
+        >>> point_embedding = torch.randn(1, 100, 256)
+        >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+        >>> print(output_queries.shape, output_image.shape)
     """
 
     def __init__(
@@ -36,15 +48,32 @@ class TwoWayTransformer(nn.Module):
         attention_downsample_rate: int = 2,
     ) -> None:
         """
-
+        Initialize a Two-Way Transformer for simultaneous attention to image and query points.
 
         Args:
-
-
-
-
-
-
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for input embeddings.
+            num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            activation (Type[nn.Module]): Activation function to use in the MLP block.
+            attention_downsample_rate (int): Downsampling rate for attention mechanism.
+
+        Attributes:
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for input embeddings.
+            num_heads (int): Number of heads for multihead attention.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
+            final_attn_token_to_image (Attention): Final attention layer from queries to image.
+            norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> image_embedding = torch.randn(1, 256, 32, 32)
+            >>> image_pe = torch.randn(1, 256, 32, 32)
+            >>> point_embedding = torch.randn(1, 100, 256)
+            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+            >>> print(output_queries.shape, output_image.shape)
         """
         super().__init__()
         self.depth = depth
@@ -75,18 +104,25 @@ class TwoWayTransformer(nn.Module):
         point_embedding: Tensor,
     ) -> Tuple[Tensor, Tensor]:
         """
+        Processes image and point embeddings through the Two-Way Transformer.
+
         Args:
-
-
-
-                Must have shape B x N_points x embedding_dim for any N_points.
+            image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+            image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+            point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
 
         Returns:
-
-
+            (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> image_embedding = torch.randn(1, 256, 32, 32)
+            >>> image_pe = torch.randn(1, 256, 32, 32)
+            >>> point_embedding = torch.randn(1, 100, 256)
+            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+            >>> print(output_queries.shape, output_image.shape)
         """
         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
-        bs, c, h, w = image_embedding.shape
         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
         image_pe = image_pe.flatten(2).permute(0, 2, 1)
 
@@ -115,21 +151,34 @@ class TwoWayTransformer(nn.Module):
 
 class TwoWayAttentionBlock(nn.Module):
     """
-
-
-
-    sparse inputs
+    A two-way attention block for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
+    cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
+    inputs to sparse inputs.
 
     Attributes:
-        self_attn (Attention):
-        norm1 (nn.LayerNorm): Layer normalization
+        self_attn (Attention): Self-attention layer for queries.
+        norm1 (nn.LayerNorm): Layer normalization after self-attention.
         cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
-        norm2 (nn.LayerNorm): Layer normalization
-        mlp (MLPBlock): MLP block
-        norm3 (nn.LayerNorm): Layer normalization
-        norm4 (nn.LayerNorm): Layer normalization
+        norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
+        mlp (MLPBlock): MLP block for transforming query embeddings.
+        norm3 (nn.LayerNorm): Layer normalization after MLP block.
+        norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
         cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
-        skip_first_layer_pe (bool): Whether to skip
+        skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+    Methods:
+        forward: Applies self-attention and cross-attention to queries and keys.
+
+    Examples:
+        >>> embedding_dim, num_heads = 256, 8
+        >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+        >>> queries = torch.randn(1, 100, embedding_dim)
+        >>> keys = torch.randn(1, 1000, embedding_dim)
+        >>> query_pe = torch.randn(1, 100, embedding_dim)
+        >>> key_pe = torch.randn(1, 1000, embedding_dim)
+        >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
     """
 
     def __init__(
@@ -142,16 +191,28 @@ class TwoWayAttentionBlock(nn.Module):
         skip_first_layer_pe: bool = False,
     ) -> None:
         """
-
-
-
+        Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
+
+        This block implements a specialized transformer layer with four main components: self-attention on sparse
+        inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
+        of dense inputs to sparse inputs.
 
         Args:
-
-
-
-
-
+            embedding_dim (int): Channel dimension of the embeddings.
+            num_heads (int): Number of attention heads in the attention layers.
+            mlp_dim (int): Hidden dimension of the MLP block.
+            activation (Type[nn.Module]): Activation function for the MLP block.
+            attention_downsample_rate (int): Downsampling rate for the attention mechanism.
+            skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+        Examples:
+            >>> embedding_dim, num_heads = 256, 8
+            >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+            >>> queries = torch.randn(1, 100, embedding_dim)
+            >>> keys = torch.randn(1, 1000, embedding_dim)
+            >>> query_pe = torch.randn(1, 100, embedding_dim)
+            >>> key_pe = torch.randn(1, 1000, embedding_dim)
+            >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
         """
         super().__init__()
         self.self_attn = Attention(embedding_dim, num_heads)
@@ -169,8 +230,7 @@ class TwoWayAttentionBlock(nn.Module):
         self.skip_first_layer_pe = skip_first_layer_pe
 
     def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
-        """
-
+        """Applies two-way attention to process query and key embeddings in a transformer block."""
         # Self attention block
         if self.skip_first_layer_pe:
             queries = self.self_attn(q=queries, k=queries, v=queries)
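The hunks above describe TwoWayAttentionBlock as four sequential sub-layers with residual connections and layer norms: self-attention on the sparse queries, cross-attention of queries to the dense image keys, an MLP on the queries, and cross-attention of the keys back to the queries. The snippet below is a minimal, self-contained sketch of that flow built from plain torch.nn.MultiheadAttention; the class name TwoWaySketch, the ReLU MLP, and all tensor sizes are illustrative assumptions, not the ultralytics implementation.

import torch
import torch.nn as nn


class TwoWaySketch(nn.Module):
    # Illustrative only: mirrors the four-stage flow described in the docstring above.
    def __init__(self, dim=256, heads=8, mlp_dim=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.t2i_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
        self.i2t_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm1, self.norm2, self.norm3, self.norm4 = (nn.LayerNorm(dim) for _ in range(4))

    def forward(self, queries, keys, query_pe, key_pe):
        q = queries + query_pe
        queries = self.norm1(queries + self.self_attn(q, q, queries)[0])  # 1) self-attention on sparse queries
        q, k = queries + query_pe, keys + key_pe
        queries = self.norm2(queries + self.t2i_attn(q, k, keys)[0])      # 2) queries attend to dense keys
        queries = self.norm3(queries + self.mlp(queries))                 # 3) MLP on sparse queries
        q, k = queries + query_pe, keys + key_pe
        keys = self.norm4(keys + self.i2t_attn(k, q, queries)[0])         # 4) dense keys attend back to queries
        return queries, keys


queries, keys = torch.randn(1, 100, 256), torch.randn(1, 4096, 256)
query_pe, key_pe = torch.randn(1, 100, 256), torch.randn(1, 4096, 256)
q_out, k_out = TwoWaySketch()(queries, keys, query_pe, key_pe)
print(q_out.shape, k_out.shape)  # torch.Size([1, 100, 256]) torch.Size([1, 4096, 256])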
@@ -203,8 +263,34 @@ class TwoWayAttentionBlock(nn.Module):
 
 
 class Attention(nn.Module):
-    """
-
+    """
+    An attention layer with downscaling capability for embedding size after projection.
+
+    This class implements a multi-head attention mechanism with the option to downsample the internal
+    dimension of queries, keys, and values.
+
+    Attributes:
+        embedding_dim (int): Dimensionality of input embeddings.
+        kv_in_dim (int): Dimensionality of key and value inputs.
+        internal_dim (int): Internal dimension after downsampling.
+        num_heads (int): Number of attention heads.
+        q_proj (nn.Linear): Linear projection for queries.
+        k_proj (nn.Linear): Linear projection for keys.
+        v_proj (nn.Linear): Linear projection for values.
+        out_proj (nn.Linear): Linear projection for output.
+
+    Methods:
+        _separate_heads: Separates input tensor into attention heads.
+        _recombine_heads: Recombines separated attention heads.
+        forward: Computes attention output for given query, key, and value tensors.
+
+    Examples:
+        >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+        >>> q = torch.randn(1, 100, 256)
+        >>> k = v = torch.randn(1, 50, 256)
+        >>> output = attn(q, k, v)
+        >>> print(output.shape)
+        torch.Size([1, 100, 256])
     """
 
     def __init__(
@@ -212,46 +298,59 @@ class Attention(nn.Module):
         embedding_dim: int,
         num_heads: int,
         downsample_rate: int = 1,
+        kv_in_dim: int = None,
     ) -> None:
         """
-        Initializes the Attention
+        Initializes the Attention module with specified dimensions and settings.
+
+        This class implements a multi-head attention mechanism with optional downsampling of the internal
+        dimension for queries, keys, and values.
 
         Args:
-            embedding_dim (int):
-            num_heads (int):
-            downsample_rate (int
+            embedding_dim (int): Dimensionality of input embeddings.
+            num_heads (int): Number of attention heads.
+            downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
+            kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
 
         Raises:
-            AssertionError: If
+            AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
+
+        Examples:
+            >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+            >>> q = torch.randn(1, 100, 256)
+            >>> k = v = torch.randn(1, 50, 256)
+            >>> output = attn(q, k, v)
+            >>> print(output.shape)
+            torch.Size([1, 100, 256])
         """
         super().__init__()
         self.embedding_dim = embedding_dim
+        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
 
         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.k_proj = nn.Linear(
-        self.v_proj = nn.Linear(
+        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
 
     @staticmethod
     def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
-        """
+        """Separates the input tensor into the specified number of attention heads."""
         b, n, c = x.shape
         x = x.reshape(b, n, num_heads, c // num_heads)
         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
 
     @staticmethod
     def _recombine_heads(x: Tensor) -> Tensor:
-        """
+        """Recombines separated attention heads into a single tensor."""
        b, n_heads, n_tokens, c_per_head = x.shape
         x = x.transpose(1, 2)
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
-        """
-
+        """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
         # Input projections
         q = self.q_proj(q)
         k = self.k_proj(k)
ultralytics/models/sam/modules/utils.py (new file)
@@ -0,0 +1,293 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+    """
+    Selects the closest conditioning frames to a given frame index.
+
+    Args:
+        frame_idx (int): Current frame index.
+        cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
+        max_cond_frame_num (int): Maximum number of conditioning frames to select.
+
+    Returns:
+        (Tuple[Dict[int, Any], Dict[int, Any]]): A tuple containing two dictionaries:
+            - selected_outputs: Selected items from cond_frame_outputs.
+            - unselected_outputs: Items not selected from cond_frame_outputs.
+
+    Examples:
+        >>> frame_idx = 5
+        >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
+        >>> max_cond_frame_num = 2
+        >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
+        >>> print(selected)
+        {3: 'b', 7: 'c'}
+        >>> print(unselected)
+        {1: 'a', 9: 'd'}
+    """
+    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+        selected_outputs = cond_frame_outputs
+        unselected_outputs = {}
+    else:
+        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+        selected_outputs = {}
+
+        # the closest conditioning frame before `frame_idx` (if any)
+        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+        if idx_before is not None:
+            selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+        # the closest conditioning frame after `frame_idx` (if any)
+        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+        if idx_after is not None:
+            selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+        # add other temporally closest conditioning frames until reaching a total
+        # of `max_cond_frame_num` conditioning frames.
+        num_remain = max_cond_frame_num - len(selected_outputs)
+        inds_remain = sorted(
+            (t for t in cond_frame_outputs if t not in selected_outputs),
+            key=lambda x: abs(x - frame_idx),
+        )[:num_remain]
+        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+        unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
+
+    return selected_outputs, unselected_outputs
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+    """Generates 1D sinusoidal positional embeddings for given positions and dimensions."""
+    pe_dim = dim // 2
+    dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+    pos_embed = pos_inds.unsqueeze(-1) / dim_t
+    pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+    return pos_embed
+
+
+def init_t_xy(end_x: int, end_y: int):
+    """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
+    t = torch.arange(end_x * end_y, dtype=torch.float32)
+    t_x = (t % end_x).float()
+    t_y = torch.div(t, end_x, rounding_mode="floor").float()
+    return t_x, t_y
+
+
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+    """Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
+    freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+    freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+    t_x, t_y = init_t_xy(end_x, end_y)
+    freqs_x = torch.outer(t_x, freqs_x)
+    freqs_y = torch.outer(t_y, freqs_y)
+    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+    shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_enc(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    repeat_freqs_k: bool = False,
+):
+    """Applies rotary positional encoding to query and key tensors using complex-valued frequency components."""
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    if xk_ is None:
+        # no keys to rotate, due to dropout
+        return xq_out.type_as(xq).to(xq.device), xk
+    # repeat freqs along seq_len dim to match k seq_len
+    if repeat_freqs_k:
+        r = xk_.shape[-2] // xq_.shape[-2]
+        freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
+
+
+def window_partition(x, window_size):
+    """
+    Partitions input tensor into non-overlapping windows with padding if needed.
+
+    Args:
+        x (torch.Tensor): Input tensor with shape (B, H, W, C).
+        window_size (int): Size of each window.
+
+    Returns:
+        (Tuple[torch.Tensor, Tuple[int, int]]): A tuple containing:
+            - windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
+            - (Hp, Wp) (Tuple[int, int]): Padded height and width before partition.
+
+    Examples:
+        >>> x = torch.randn(1, 16, 16, 3)
+        >>> windows, (Hp, Wp) = window_partition(x, window_size=4)
+        >>> print(windows.shape, Hp, Wp)
+        torch.Size([16, 4, 4, 3]) 16 16
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+    """
+    Unpartitions windowed sequences into original sequences and removes padding.
+
+    This function reverses the windowing process, reconstructing the original input from windowed segments
+    and removing any padding that was added during the windowing process.
+
+    Args:
+        windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
+            window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
+            the size of each window, and C is the number of channels.
+        window_size (int): Size of each window.
+        pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
+        hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
+
+    Returns:
+        (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
+            are the original height and width, and C is the number of channels.
+
+    Examples:
+        >>> windows = torch.rand(32, 8, 8, 64)  # 32 windows of size 8x8 with 64 channels
+        >>> pad_hw = (16, 16)  # Padded height and width
+        >>> hw = (15, 14)  # Original height and width
+        >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
+        >>> print(x.shape)
+        torch.Size([1, 15, 14, 64])
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Extracts relative positional embeddings based on query and key sizes.
+
+    Args:
+        q_size (int): Size of the query.
+        k_size (int): Size of the key.
+        rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
+            distance and C is the embedding dimension.
+
+    Returns:
+        (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
+            k_size, C).
+
+    Examples:
+        >>> q_size, k_size = 8, 16
+        >>> rel_pos = torch.randn(31, 64)  # 31 = 2 * max(8, 16) - 1
+        >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
+        >>> print(extracted_pos.shape)
+        torch.Size([8, 16, 64])
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Adds decomposed Relative Positional Embeddings to the attention map.
+
+    This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
+    paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
+    positions.
+
+    Args:
+        attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
+        q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
+        rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
+        q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
+        k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
+
+    Returns:
+        (torch.Tensor): Updated attention map with added relative positional embeddings, shape
+            (B, q_h * q_w, k_h * k_w).
+
+    Examples:
+        >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
+        >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
+        >>> q = torch.rand(B, q_h * q_w, C)
+        >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
+        >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
+        >>> q_size, k_size = (q_h, q_w), (k_h, k_w)
+        >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
+        >>> print(updated_attn.shape)
+        torch.Size([1, 64, 64])
+
+    References:
+        https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+        B, q_h * q_w, k_h * k_w
+    )
+
+    return attn