dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import math
|
|
5
6
|
from typing import Any
|
|
6
7
|
|
|
7
8
|
import torch
|
|
@@ -9,8 +10,7 @@ import torch.nn.functional as F
|
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
|
|
12
|
-
"""
|
|
13
|
-
Select the closest conditioning frames to a given frame index.
|
|
13
|
+
"""Select the closest conditioning frames to a given frame index.
|
|
14
14
|
|
|
15
15
|
Args:
|
|
16
16
|
frame_idx (int): Current frame index.
|
|
@@ -62,8 +62,7 @@ def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any
|
|
|
62
62
|
|
|
63
63
|
|
|
64
64
|
def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
|
|
65
|
-
"""
|
|
66
|
-
Generate 1D sinusoidal positional embeddings for given positions and dimensions.
|
|
65
|
+
"""Generate 1D sinusoidal positional embeddings for given positions and dimensions.
|
|
67
66
|
|
|
68
67
|
Args:
|
|
69
68
|
pos_inds (torch.Tensor): Position indices for which to generate embeddings.
|
|
@@ -88,16 +87,17 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
|
|
|
88
87
|
return pos_embed
|
|
89
88
|
|
|
90
89
|
|
|
91
|
-
def init_t_xy(end_x: int, end_y: int):
|
|
92
|
-
"""
|
|
93
|
-
Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
|
|
90
|
+
def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
|
|
91
|
+
"""Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
|
|
94
92
|
|
|
95
|
-
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
|
|
96
|
-
and corresponding x and y coordinate tensors.
|
|
93
|
+
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
|
|
94
|
+
tensor and corresponding x and y coordinate tensors.
|
|
97
95
|
|
|
98
96
|
Args:
|
|
99
97
|
end_x (int): Width of the grid (number of columns).
|
|
100
98
|
end_y (int): Height of the grid (number of rows).
|
|
99
|
+
scale (float): Scaling factor to apply to the coordinates.
|
|
100
|
+
offset (int): Offset to add to the coordinates.
|
|
101
101
|
|
|
102
102
|
Returns:
|
|
103
103
|
t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
|
|
@@ -113,21 +113,21 @@ def init_t_xy(end_x: int, end_y: int):
|
|
|
113
113
|
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
|
114
114
|
t_x = (t % end_x).float()
|
|
115
115
|
t_y = torch.div(t, end_x, rounding_mode="floor").float()
|
|
116
|
-
return t_x, t_y
|
|
116
|
+
return t_x * scale + offset, t_y * scale + offset
|
|
117
117
|
|
|
118
118
|
|
|
119
|
-
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|
120
|
-
"""
|
|
121
|
-
Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
|
|
119
|
+
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
|
|
120
|
+
"""Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
|
|
122
121
|
|
|
123
|
-
This function generates complex exponential positional encodings for a 2D grid of spatial positions,
|
|
124
|
-
|
|
122
|
+
This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
|
|
123
|
+
frequency components for the x and y dimensions.
|
|
125
124
|
|
|
126
125
|
Args:
|
|
127
126
|
dim (int): Dimension of the positional encoding.
|
|
128
127
|
end_x (int): Width of the 2D grid.
|
|
129
128
|
end_y (int): Height of the 2D grid.
|
|
130
129
|
theta (float, optional): Scaling factor for frequency computation.
|
|
130
|
+
scale_pos (float, optional): Scaling factor for position coordinates.
|
|
131
131
|
|
|
132
132
|
Returns:
|
|
133
133
|
(torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).
|
|
@@ -141,7 +141,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|
|
141
141
|
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
|
142
142
|
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
|
143
143
|
|
|
144
|
-
t_x, t_y = init_t_xy(end_x, end_y)
|
|
144
|
+
t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos)
|
|
145
145
|
freqs_x = torch.outer(t_x, freqs_x)
|
|
146
146
|
freqs_y = torch.outer(t_y, freqs_y)
|
|
147
147
|
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
|
|
@@ -150,11 +150,10 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|
|
150
150
|
|
|
151
151
|
|
|
152
152
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
|
153
|
-
"""
|
|
154
|
-
Reshape frequency tensor for broadcasting with input tensor.
|
|
153
|
+
"""Reshape frequency tensor for broadcasting with input tensor.
|
|
155
154
|
|
|
156
|
-
Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
|
|
157
|
-
|
|
155
|
+
Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor. This function
|
|
156
|
+
is typically used in positional encoding operations.
|
|
158
157
|
|
|
159
158
|
Args:
|
|
160
159
|
freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
|
|
@@ -167,7 +166,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
|
|
167
166
|
AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.
|
|
168
167
|
"""
|
|
169
168
|
ndim = x.ndim
|
|
170
|
-
assert
|
|
169
|
+
assert ndim >= 2
|
|
171
170
|
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
|
172
171
|
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
|
|
173
172
|
return freqs_cis.view(*shape)
|
|
@@ -179,8 +178,7 @@ def apply_rotary_enc(
|
|
|
179
178
|
freqs_cis: torch.Tensor,
|
|
180
179
|
repeat_freqs_k: bool = False,
|
|
181
180
|
):
|
|
182
|
-
"""
|
|
183
|
-
Apply rotary positional encoding to query and key tensors.
|
|
181
|
+
"""Apply rotary positional encoding to query and key tensors.
|
|
184
182
|
|
|
185
183
|
This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
|
|
186
184
|
components. RoPE is a technique that injects relative position information into self-attention mechanisms.
|
|
@@ -188,10 +186,10 @@ def apply_rotary_enc(
|
|
|
188
186
|
Args:
|
|
189
187
|
xq (torch.Tensor): Query tensor to encode with positional information.
|
|
190
188
|
xk (torch.Tensor): Key tensor to encode with positional information.
|
|
191
|
-
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
|
|
192
|
-
|
|
193
|
-
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
|
|
194
|
-
|
|
189
|
+
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the last
|
|
190
|
+
two dimensions of xq.
|
|
191
|
+
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension to match
|
|
192
|
+
key sequence length.
|
|
195
193
|
|
|
196
194
|
Returns:
|
|
197
195
|
xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
|
|
@@ -212,16 +210,20 @@ def apply_rotary_enc(
|
|
|
212
210
|
# No keys to rotate, due to dropout
|
|
213
211
|
return xq_out.type_as(xq).to(xq.device), xk
|
|
214
212
|
# Repeat freqs along seq_len dim to match k seq_len
|
|
215
|
-
if repeat_freqs_k:
|
|
216
|
-
|
|
217
|
-
|
|
213
|
+
if repeat_freqs_k and (r := xk_.shape[-2] // xq_.shape[-2]) > 1:
|
|
214
|
+
# MPS doesn't support repeat on complex tensors, decompose to real representation
|
|
215
|
+
if freqs_cis.device.type == "mps":
|
|
216
|
+
freqs_cis = torch.view_as_real(freqs_cis)
|
|
217
|
+
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 3)), r, 1, 1)
|
|
218
|
+
freqs_cis = torch.view_as_complex(freqs_cis.contiguous())
|
|
219
|
+
else:
|
|
220
|
+
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
|
218
221
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
|
219
222
|
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
|
220
223
|
|
|
221
224
|
|
|
222
225
|
def window_partition(x: torch.Tensor, window_size: int):
|
|
223
|
-
"""
|
|
224
|
-
Partition input tensor into non-overlapping windows with padding if needed.
|
|
226
|
+
"""Partition input tensor into non-overlapping windows with padding if needed.
|
|
225
227
|
|
|
226
228
|
Args:
|
|
227
229
|
x (torch.Tensor): Input tensor with shape (B, H, W, C).
|
|
@@ -251,23 +253,22 @@ def window_partition(x: torch.Tensor, window_size: int):
|
|
|
251
253
|
|
|
252
254
|
|
|
253
255
|
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
|
|
254
|
-
"""
|
|
255
|
-
Unpartition windowed sequences into original sequences and remove padding.
|
|
256
|
+
"""Unpartition windowed sequences into original sequences and remove padding.
|
|
256
257
|
|
|
257
|
-
This function reverses the windowing process, reconstructing the original input from windowed segments
|
|
258
|
-
|
|
258
|
+
This function reverses the windowing process, reconstructing the original input from windowed segments and removing
|
|
259
|
+
any padding that was added during the windowing process.
|
|
259
260
|
|
|
260
261
|
Args:
|
|
261
262
|
windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
|
|
262
|
-
window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
|
|
263
|
-
|
|
263
|
+
window_size, C), where B is the batch size, num_windows is the number of windows, window_size is the size of
|
|
264
|
+
each window, and C is the number of channels.
|
|
264
265
|
window_size (int): Size of each window.
|
|
265
266
|
pad_hw (tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
|
|
266
267
|
hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
|
|
267
268
|
|
|
268
269
|
Returns:
|
|
269
|
-
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
|
|
270
|
-
|
|
270
|
+
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W are the
|
|
271
|
+
original height and width, and C is the number of channels.
|
|
271
272
|
|
|
272
273
|
Examples:
|
|
273
274
|
>>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
|
|
@@ -289,18 +290,16 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
|
|
|
289
290
|
|
|
290
291
|
|
|
291
292
|
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
|
292
|
-
"""
|
|
293
|
-
Extract relative positional embeddings based on query and key sizes.
|
|
293
|
+
"""Extract relative positional embeddings based on query and key sizes.
|
|
294
294
|
|
|
295
295
|
Args:
|
|
296
296
|
q_size (int): Size of the query.
|
|
297
297
|
k_size (int): Size of the key.
|
|
298
|
-
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
|
|
299
|
-
|
|
298
|
+
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative distance
|
|
299
|
+
and C is the embedding dimension.
|
|
300
300
|
|
|
301
301
|
Returns:
|
|
302
|
-
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
|
|
303
|
-
k_size, C).
|
|
302
|
+
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, k_size, C).
|
|
304
303
|
|
|
305
304
|
Examples:
|
|
306
305
|
>>> q_size, k_size = 8, 16
|
|
@@ -338,8 +337,7 @@ def add_decomposed_rel_pos(
|
|
|
338
337
|
q_size: tuple[int, int],
|
|
339
338
|
k_size: tuple[int, int],
|
|
340
339
|
) -> torch.Tensor:
|
|
341
|
-
"""
|
|
342
|
-
Add decomposed Relative Positional Embeddings to the attention map.
|
|
340
|
+
"""Add decomposed Relative Positional Embeddings to the attention map.
|
|
343
341
|
|
|
344
342
|
This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
|
|
345
343
|
paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
|
|
@@ -354,8 +352,8 @@ def add_decomposed_rel_pos(
|
|
|
354
352
|
k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
|
|
355
353
|
|
|
356
354
|
Returns:
|
|
357
|
-
(torch.Tensor): Updated attention map with added relative positional embeddings, shape
|
|
358
|
-
|
|
355
|
+
(torch.Tensor): Updated attention map with added relative positional embeddings, shape (B, q_h * q_w, k_h *
|
|
356
|
+
k_w).
|
|
359
357
|
|
|
360
358
|
Examples:
|
|
361
359
|
>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
|
|
@@ -386,3 +384,129 @@ def add_decomposed_rel_pos(
|
|
|
386
384
|
)
|
|
387
385
|
|
|
388
386
|
return attn
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def get_abs_pos(
|
|
390
|
+
abs_pos: torch.Tensor,
|
|
391
|
+
has_cls_token: bool,
|
|
392
|
+
hw: tuple[int, int],
|
|
393
|
+
retain_cls_token: bool = False,
|
|
394
|
+
tiling: bool = False,
|
|
395
|
+
) -> torch.Tensor:
|
|
396
|
+
"""Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token dimension for the
|
|
397
|
+
original embeddings.
|
|
398
|
+
|
|
399
|
+
Args:
|
|
400
|
+
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
|
|
401
|
+
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
|
|
402
|
+
hw (Tuple): size of input image tokens.
|
|
403
|
+
retain_cls_token: whether to retain the cls_token
|
|
404
|
+
tiling: whether to tile the embeddings, *instead* of interpolation (a la abs_win)
|
|
405
|
+
|
|
406
|
+
Returns:
|
|
407
|
+
Absolute positional embeddings after processing with shape (1, H, W, C),: if retain_cls_token is False,
|
|
408
|
+
otherwise (1, 1+H*W, C).
|
|
409
|
+
"""
|
|
410
|
+
if retain_cls_token:
|
|
411
|
+
assert has_cls_token
|
|
412
|
+
|
|
413
|
+
h, w = hw
|
|
414
|
+
if has_cls_token:
|
|
415
|
+
cls_pos = abs_pos[:, :1]
|
|
416
|
+
abs_pos = abs_pos[:, 1:]
|
|
417
|
+
|
|
418
|
+
xy_num = abs_pos.shape[1]
|
|
419
|
+
size = int(math.sqrt(xy_num))
|
|
420
|
+
assert size * size == xy_num
|
|
421
|
+
|
|
422
|
+
if size != h or size != w:
|
|
423
|
+
new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2)
|
|
424
|
+
if tiling:
|
|
425
|
+
new_abs_pos = new_abs_pos.tile([1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])])[
|
|
426
|
+
:, :, :h, :w
|
|
427
|
+
]
|
|
428
|
+
else:
|
|
429
|
+
new_abs_pos = F.interpolate(
|
|
430
|
+
new_abs_pos,
|
|
431
|
+
size=(h, w),
|
|
432
|
+
mode="bicubic",
|
|
433
|
+
align_corners=False,
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
if not retain_cls_token:
|
|
437
|
+
return new_abs_pos.permute(0, 2, 3, 1)
|
|
438
|
+
else:
|
|
439
|
+
# add cls_token back, flatten spatial dims
|
|
440
|
+
assert has_cls_token
|
|
441
|
+
return torch.cat(
|
|
442
|
+
[cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)],
|
|
443
|
+
dim=1,
|
|
444
|
+
)
|
|
445
|
+
|
|
446
|
+
else:
|
|
447
|
+
if not retain_cls_token:
|
|
448
|
+
return abs_pos.reshape(1, h, w, -1)
|
|
449
|
+
else:
|
|
450
|
+
assert has_cls_token
|
|
451
|
+
return torch.cat([cls_pos, abs_pos], dim=1)
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def concat_rel_pos(
|
|
455
|
+
q: torch.Tensor,
|
|
456
|
+
k: torch.Tensor,
|
|
457
|
+
q_hw: tuple[int, int],
|
|
458
|
+
k_hw: tuple[int, int],
|
|
459
|
+
rel_pos_h: torch.Tensor,
|
|
460
|
+
rel_pos_w: torch.Tensor,
|
|
461
|
+
rescale: bool = False,
|
|
462
|
+
relative_coords: torch.Tensor = None,
|
|
463
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
464
|
+
"""Concatenate rel pos coeffs to the q & k tensors, so that qk^T is now effectively including rel pos biases.
|
|
465
|
+
|
|
466
|
+
Args:
|
|
467
|
+
q (torch.Tensor): q tensor with shape (B, L_q, C).
|
|
468
|
+
k (torch.Tensor): k tensor with shape (B, L_k, C).
|
|
469
|
+
q_hw: These are spatial size of q tensors.
|
|
470
|
+
k_hw: These are spatial size of k tensors.
|
|
471
|
+
rel_pos_h: These are relative pos embeddings/params of height.
|
|
472
|
+
rel_pos_w: These are relative pos embeddings/params of width.
|
|
473
|
+
rescale (bool): whether to rescale. e.g. for use when using sdpa, pytorch will scale by the wrong factor due to
|
|
474
|
+
the concat.
|
|
475
|
+
relative_coords (torch.Tensor, optional): Precomputed relative coords index tensor.
|
|
476
|
+
|
|
477
|
+
Returns:
|
|
478
|
+
q, k: But, padded so that qk^T accounts for rel pos biases.
|
|
479
|
+
"""
|
|
480
|
+
q_h, q_w = q_hw
|
|
481
|
+
k_h, k_w = k_hw
|
|
482
|
+
|
|
483
|
+
assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"
|
|
484
|
+
|
|
485
|
+
if relative_coords is not None:
|
|
486
|
+
Rh = rel_pos_h[relative_coords]
|
|
487
|
+
Rw = rel_pos_w[relative_coords]
|
|
488
|
+
else:
|
|
489
|
+
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
|
490
|
+
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
|
491
|
+
|
|
492
|
+
B, _, dim = q.shape
|
|
493
|
+
r_q = q.reshape(B, q_h, q_w, dim)
|
|
494
|
+
|
|
495
|
+
old_scale = dim**0.5
|
|
496
|
+
new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale # for sdpa
|
|
497
|
+
# attn will be divided by new_scale, but we want to divide q by old_scale
|
|
498
|
+
scale_ratio = new_scale / old_scale
|
|
499
|
+
|
|
500
|
+
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale # (B, q_h, q_w, k_h)
|
|
501
|
+
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale # (B, q_h, q_w, k_w)
|
|
502
|
+
|
|
503
|
+
eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
|
|
504
|
+
eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)
|
|
505
|
+
|
|
506
|
+
eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
|
|
507
|
+
eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])
|
|
508
|
+
|
|
509
|
+
q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
|
|
510
|
+
k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(B, k_h * k_w, -1)
|
|
511
|
+
|
|
512
|
+
return q, k
|