dgenerate-ultralytics-headless 8.3.235-py3-none-any.whl → 8.3.237-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/RECORD +41 -28
- tests/test_exports.py +15 -1
- ultralytics/__init__.py +1 -1
- ultralytics/engine/exporter.py +113 -12
- ultralytics/engine/predictor.py +3 -2
- ultralytics/engine/trainer.py +8 -0
- ultralytics/models/rtdetr/val.py +5 -1
- ultralytics/models/sam/__init__.py +14 -1
- ultralytics/models/sam/build.py +17 -8
- ultralytics/models/sam/build_sam3.py +374 -0
- ultralytics/models/sam/model.py +12 -4
- ultralytics/models/sam/modules/blocks.py +20 -8
- ultralytics/models/sam/modules/decoders.py +2 -3
- ultralytics/models/sam/modules/encoders.py +4 -1
- ultralytics/models/sam/modules/memory_attention.py +6 -2
- ultralytics/models/sam/modules/sam.py +150 -6
- ultralytics/models/sam/modules/utils.py +134 -4
- ultralytics/models/sam/predict.py +2076 -118
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +535 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +198 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +357 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/tokenizer_ve.py +242 -0
- ultralytics/models/sam/sam3/vitdet.py +546 -0
- ultralytics/models/sam/sam3/vl_combiner.py +165 -0
- ultralytics/models/yolo/obb/val.py +18 -7
- ultralytics/nn/autobackend.py +35 -0
- ultralytics/nn/modules/transformer.py +21 -1
- ultralytics/utils/checks.py +41 -0
- ultralytics/utils/ops.py +1 -3
- ultralytics/utils/torch_utils.py +1 -0
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/sam.py:

```diff
@@ -13,7 +13,7 @@ from torch.nn.init import trunc_normal_
 from ultralytics.nn.modules import MLP
 from ultralytics.utils import LOGGER
 
-from .blocks import SAM2TwoWayTransformer
+from .blocks import SAM2TwoWayTransformer, TwoWayTransformer
 from .decoders import MaskDecoder, SAM2MaskDecoder
 from .encoders import ImageEncoderViT, PromptEncoder
 from .utils import get_1d_sine_pe, select_closest_cond_frames
```
```diff
@@ -329,6 +329,7 @@ class SAM2Model(torch.nn.Module):
 
         self._build_sam_heads()
         self.max_cond_frames_in_attn = max_cond_frames_in_attn
+        self.add_all_frames_to_correct_as_cond = True
 
         # Model compilation
         if compile_image_encoder:
```
```diff
@@ -473,7 +474,7 @@ class SAM2Model(torch.nn.Module):
             assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
             if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                 sam_mask_prompt = F.interpolate(
-                    mask_inputs.float(),
+                    mask_inputs.to(backbone_features.dtype),
                     size=self.sam_prompt_encoder.mask_input_size,
                     align_corners=False,
                     mode="bilinear",
```
```diff
@@ -571,7 +572,7 @@ class SAM2Model(torch.nn.Module):
         # produce an object pointer using the SAM decoder from the mask input
         _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
             backbone_features=backbone_features,
-            mask_inputs=self.mask_downsample(mask_inputs_float),
+            mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)),
             high_res_features=high_res_features,
         )
         # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
```
```diff
@@ -818,7 +819,6 @@ class SAM2Model(torch.nn.Module):
             mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
         maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
         maskmem_features = maskmem_out["vision_features"]
-        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
         # add a no-object embedding to the spatial memory to indicate that the frame
         # is predicted to be occluded (i.e. no object is appearing in the frame)
         if self.no_obj_embed_spatial is not None:
```
```diff
@@ -827,7 +827,7 @@ class SAM2Model(torch.nn.Module):
                 ..., None, None
             ].expand(*maskmem_features.shape)
 
-        return maskmem_features, maskmem_pos_enc
+        return maskmem_features, maskmem_out["vision_pos_enc"]
 
     def _track_step(
         self,
```
```diff
@@ -1005,7 +1005,151 @@ class SAM2Model(torch.nn.Module):
 
     def set_imgsz(self, imgsz):
         """Set image size to make model compatible with different image sizes."""
+        if hasattr(self.image_encoder, "set_imgsz"):
+            self.image_encoder.set_imgsz(imgsz)
         self.image_size = imgsz[0]
         self.sam_prompt_encoder.input_image_size = imgsz
-        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.image_embedding_size = [
+            x // self.backbone_stride for x in imgsz
+        ]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.mask_input_size = [
+            x // self.backbone_stride * 4 for x in imgsz
+        ]  # fixed ViT patch size of 16
         self.sam_image_embedding_size = self.image_size // self.backbone_stride  # update image embedding size
+
+
+class SAM3Model(SAM2Model):
+    """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+
+    def __init__(
+        self,
+        image_encoder,
+        memory_attention,
+        memory_encoder,
+        num_maskmem=7,
+        image_size=1008,
+        backbone_stride=14,
+        sigmoid_scale_for_mem_enc=1,
+        sigmoid_bias_for_mem_enc=0,
+        binarize_mask_from_pts_for_mem_enc=False,
+        use_mask_input_as_output_without_sam=False,
+        max_cond_frames_in_attn=-1,
+        directly_add_no_mem_embed=False,
+        use_high_res_features_in_sam=False,
+        multimask_output_in_sam=False,
+        multimask_min_pt_num=1,
+        multimask_max_pt_num=1,
+        multimask_output_for_tracking=False,
+        use_multimask_token_for_obj_ptr: bool = False,
+        iou_prediction_use_sigmoid=False,
+        memory_temporal_stride_for_eval=1,
+        non_overlap_masks_for_mem_enc=False,
+        use_obj_ptrs_in_encoder=False,
+        max_obj_ptrs_in_encoder=16,
+        add_tpos_enc_to_obj_ptrs=True,
+        proj_tpos_enc_in_obj_ptrs=False,
+        use_signed_tpos_enc_to_obj_ptrs=False,
+        only_obj_ptrs_in_the_past_for_eval=False,
+        pred_obj_scores: bool = False,
+        pred_obj_scores_mlp: bool = False,
+        fixed_no_obj_ptr: bool = False,
+        soft_no_obj_ptr: bool = False,
+        use_mlp_for_obj_ptr_proj: bool = False,
+        no_obj_embed_spatial: bool = False,
+        sam_mask_decoder_extra_args=None,
+        compile_image_encoder: bool = False,
+    ):
+        """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+        super().__init__(
+            image_encoder,
+            memory_attention,
+            memory_encoder,
+            num_maskmem,
+            image_size,
+            backbone_stride,
+            sigmoid_scale_for_mem_enc,
+            sigmoid_bias_for_mem_enc,
+            binarize_mask_from_pts_for_mem_enc,
+            use_mask_input_as_output_without_sam,
+            max_cond_frames_in_attn,
+            directly_add_no_mem_embed,
+            use_high_res_features_in_sam,
+            multimask_output_in_sam,
+            multimask_min_pt_num,
+            multimask_max_pt_num,
+            multimask_output_for_tracking,
+            use_multimask_token_for_obj_ptr,
+            iou_prediction_use_sigmoid,
+            memory_temporal_stride_for_eval,
+            non_overlap_masks_for_mem_enc,
+            use_obj_ptrs_in_encoder,
+            max_obj_ptrs_in_encoder,
+            add_tpos_enc_to_obj_ptrs,
+            proj_tpos_enc_in_obj_ptrs,
+            use_signed_tpos_enc_to_obj_ptrs,
+            only_obj_ptrs_in_the_past_for_eval,
+            pred_obj_scores,
+            pred_obj_scores_mlp,
+            fixed_no_obj_ptr,
+            soft_no_obj_ptr,
+            use_mlp_for_obj_ptr_proj,
+            no_obj_embed_spatial,
+            sam_mask_decoder_extra_args,
+            compile_image_encoder,
+        )
+        self.sam_mask_decoder = SAM2MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=self.sam_prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=self.sam_prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            use_high_res_features=self.use_high_res_features_in_sam,
+            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+            pred_obj_scores=self.pred_obj_scores,
+            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+            **(self.sam_mask_decoder_extra_args or {}),
+        )
+
+    def forward_image(self, img_batch: torch.Tensor):
+        """Process image batch through encoder to extract multi-level features for SAM model."""
+        backbone_out = self.image_encoder.forward_image_sam2(img_batch)
+        if self.use_high_res_features_in_sam:
+            # precompute projected level 0 and level 1 features in SAM decoder
+            # to avoid running it again on every SAM click
+            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+        return backbone_out
+
+    def set_imgsz(self, imgsz: tuple[int, int]):
+        """Set the image size for the model and mask downsampler."""
+        super().set_imgsz(imgsz)
+        self.memory_encoder.mask_downsampler.interpol_size = [size // 14 * 16 for size in imgsz]
+
+    @staticmethod
+    def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
+        """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints."""
+        area_before = (pred_masks > 0).sum(dim=(-1, -2))
+        area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
+        area_before = torch.clamp(area_before, min=1.0)
+        area_ratio = area_after / area_before
+        keep = area_ratio >= shrink_threshold
+        keep_mask = keep[..., None, None].expand_as(pred_masks)
+        pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))
+        return pred_masks_after
+
+    def _suppress_object_pw_area_shrinkage(self, pred_masks):
+        """This function suppresses masks that shrink in area after applying pixelwise non-overlapping constraints. Note
+        that the final output can still be overlapping.
+        """
+        # Apply pixel-wise non-overlapping constraint based on mask scores
+        pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
+        # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping constraints
+        # NOTE: The output of this function can be a no op if none of the masks shrinked by a large factor.
+        pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
+        return pred_masks
```
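The last two methods in the hunk above implement a two-step filter: predicted masks are first made pixel-wise non-overlapping, and any mask whose positive area then collapses below 30% of its original size is suppressed outright, while the surviving masks keep their original (possibly still overlapping) logits. The snippet below is a minimal, self-contained sketch of that idea on toy logits; the helper names and the per-pixel argmax rule standing in for `_apply_non_overlapping_constraints` are illustrative, not the ultralytics API.

```python
# Toy illustration of the shrink-suppression step (assumed shape: (N, 1, H, W) mask logits).
import torch


def apply_non_overlapping(pred_masks: torch.Tensor) -> torch.Tensor:
    """Keep, per pixel, only the object with the highest logit; clamp the rest far below zero."""
    max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
    batch_obj_inds = torch.arange(pred_masks.size(0), device=pred_masks.device)[:, None, None, None]
    keep = max_obj_inds == batch_obj_inds
    return torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))


def suppress_shrunk_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
    """Drop whole masks whose positive area shrank below `shrink_threshold` of the original."""
    area_before = (pred_masks > 0).sum(dim=(-1, -2)).float().clamp(min=1.0)
    area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
    keep = (area_after / area_before) >= shrink_threshold
    keep_mask = keep[..., None, None].expand_as(pred_masks)
    return torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))


# Three toy objects: a strong blob, a weaker duplicate of it, and a disjoint region.
H = W = 8
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
blob = ((xx < 6) & (yy < 6)).float()
logits = torch.stack([blob * 4.0 - 2.0, blob * 3.0 - 2.0, (1.0 - blob) * 4.0 - 2.0])[:, None]

non_overlap = apply_non_overlapping(logits)
final = suppress_shrunk_masks(logits, non_overlap)
print((logits > 0).sum(dim=(-1, -2)).flatten().tolist())  # [36, 36, 28] positive pixels per object
print((final > 0).sum(dim=(-1, -2)).flatten().tolist())   # [36, 0, 28]: the duplicate blob is fully dropped
```

Because the suppression is applied to the original logits rather than the non-overlapped ones, the kept masks can still overlap, which matches the note in the `_suppress_object_pw_area_shrinkage` docstring.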
ultralytics/models/sam/modules/utils.py:

```diff
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import math
 from typing import Any
 
 import torch
```
```diff
@@ -86,7 +87,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
     return pos_embed
 
 
-def init_t_xy(end_x: int, end_y: int):
+def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
     """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
 
     This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
```
```diff
@@ -95,6 +96,8 @@ def init_t_xy(end_x: int, end_y: int):
     Args:
         end_x (int): Width of the grid (number of columns).
         end_y (int): Height of the grid (number of rows).
+        scale (float): Scaling factor to apply to the coordinates.
+        offset (int): Offset to add to the coordinates.
 
     Returns:
         t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
```
```diff
@@ -110,10 +113,10 @@ def init_t_xy(end_x: int, end_y: int):
     t = torch.arange(end_x * end_y, dtype=torch.float32)
     t_x = (t % end_x).float()
     t_y = torch.div(t, end_x, rounding_mode="floor").float()
-    return t_x, t_y
+    return t_x * scale + offset, t_y * scale + offset
 
 
-def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
     """Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
 
     This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
```
```diff
@@ -124,6 +127,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
         end_x (int): Width of the 2D grid.
         end_y (int): Height of the 2D grid.
         theta (float, optional): Scaling factor for frequency computation.
+        scale_pos (float, optional): Scaling factor for position coordinates.
 
     Returns:
         (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).
```
```diff
@@ -137,7 +141,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
     freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
     freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
 
-    t_x, t_y = init_t_xy(end_x, end_y)
+    t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos)
     freqs_x = torch.outer(t_x, freqs_x)
     freqs_y = torch.outer(t_y, freqs_y)
     freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
```
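For illustration, here is a small standalone replica of the new init_t_xy behaviour (copied from the hunk above rather than imported from the package): grid coordinates are generated exactly as before and then uniformly rescaled and shifted, which is what compute_axial_cis now forwards through its scale_pos argument.

```python
# Standalone replica of the updated init_t_xy, for illustration only.
import torch


def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()
    t_y = torch.div(t, end_x, rounding_mode="floor").float()
    return t_x * scale + offset, t_y * scale + offset


t_x, _ = init_t_xy(4, 2)                    # default: plain integer grid coordinates
t_x_scaled, _ = init_t_xy(4, 2, scale=0.5)  # same grid sampled at half spacing
print(t_x.tolist())         # [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]
print(t_x_scaled.tolist())  # [0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0, 1.5]
```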
```diff
@@ -375,3 +379,129 @@ def add_decomposed_rel_pos(
     )
 
     return attn
+
+
+def get_abs_pos(
+    abs_pos: torch.Tensor,
+    has_cls_token: bool,
+    hw: tuple[int, int],
+    retain_cls_token: bool = False,
+    tiling: bool = False,
+) -> torch.Tensor:
+    """Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token dimension for the
+    original embeddings.
+
+    Args:
+        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+        hw (Tuple): size of input image tokens.
+        retain_cls_token: whether to retain the cls_token
+        tiling: whether to tile the embeddings, *instead* of interpolation (a la abs_win)
+
+    Returns:
+        Absolute positional embeddings after processing with shape (1, H, W, C),: if retain_cls_token is False,
+        otherwise (1, 1+H*W, C).
+    """
+    if retain_cls_token:
+        assert has_cls_token
+
+    h, w = hw
+    if has_cls_token:
+        cls_pos = abs_pos[:, :1]
+        abs_pos = abs_pos[:, 1:]
+
+    xy_num = abs_pos.shape[1]
+    size = int(math.sqrt(xy_num))
+    assert size * size == xy_num
+
+    if size != h or size != w:
+        new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2)
+        if tiling:
+            new_abs_pos = new_abs_pos.tile([1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])])[
+                :, :, :h, :w
+            ]
+        else:
+            new_abs_pos = F.interpolate(
+                new_abs_pos,
+                size=(h, w),
+                mode="bicubic",
+                align_corners=False,
+            )
+
+        if not retain_cls_token:
+            return new_abs_pos.permute(0, 2, 3, 1)
+        else:
+            # add cls_token back, flatten spatial dims
+            assert has_cls_token
+            return torch.cat(
+                [cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)],
+                dim=1,
+            )
+
+    else:
+        if not retain_cls_token:
+            return abs_pos.reshape(1, h, w, -1)
+        else:
+            assert has_cls_token
+            return torch.cat([cls_pos, abs_pos], dim=1)
+
+
+def concat_rel_pos(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_hw: tuple[int, int],
+    k_hw: tuple[int, int],
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    rescale: bool = False,
+    relative_coords: torch.Tensor = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Concatenate rel pos coeffs to the q & k tensors, so that qk^T is now effectively including rel pos biases.
+
+    Args:
+        q (torch.Tensor): q tensor with shape (B, L_q, C).
+        k (torch.Tensor): k tensor with shape (B, L_k, C).
+        q_hw: These are spatial size of q tensors.
+        k_hw: These are spatial size of k tensors.
+        rel_pos_h: These are relative pos embeddings/params of height.
+        rel_pos_w: These are relative pos embeddings/params of width.
+        rescale (bool): whether to rescale. e.g. for use when using sdpa, pytorch will scale by the wrong factor due to
+            the concat.
+        relative_coords (torch.Tensor, optional): Precomputed relative coords index tensor.
+
+    Returns:
+        q, k: But, padded so that qk^T accounts for rel pos biases.
+    """
+    q_h, q_w = q_hw
+    k_h, k_w = k_hw
+
+    assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"
+
+    if relative_coords is not None:
+        Rh = rel_pos_h[relative_coords]
+        Rw = rel_pos_w[relative_coords]
+    else:
+        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+        Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+
+    old_scale = dim**0.5
+    new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale  # for sdpa
+    # attn will be divided by new_scale, but we want to divide q by old_scale
+    scale_ratio = new_scale / old_scale
+
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale  # (B, q_h, q_w, k_h)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale  # (B, q_h, q_w, k_w)
+
+    eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
+    eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)
+
+    eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
+    eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])
+
+    q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
+    k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(B, k_h * k_w, -1)
+
+    return q, k
```