dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
- dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
- tests/conftest.py +5 -8
- tests/test_cli.py +1 -8
- tests/test_python.py +1 -2
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +34 -49
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +244 -323
- ultralytics/data/base.py +12 -22
- ultralytics/data/build.py +47 -40
- ultralytics/data/converter.py +32 -42
- ultralytics/data/dataset.py +43 -71
- ultralytics/data/loaders.py +22 -34
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +27 -36
- ultralytics/engine/exporter.py +49 -116
- ultralytics/engine/model.py +144 -180
- ultralytics/engine/predictor.py +18 -29
- ultralytics/engine/results.py +165 -231
- ultralytics/engine/trainer.py +11 -19
- ultralytics/engine/tuner.py +13 -23
- ultralytics/engine/validator.py +6 -10
- ultralytics/hub/__init__.py +7 -12
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +3 -6
- ultralytics/models/fastsam/model.py +6 -8
- ultralytics/models/fastsam/predict.py +5 -10
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +2 -4
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -18
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +13 -20
- ultralytics/models/sam/amg.py +12 -18
- ultralytics/models/sam/build.py +6 -9
- ultralytics/models/sam/model.py +16 -23
- ultralytics/models/sam/modules/blocks.py +62 -84
- ultralytics/models/sam/modules/decoders.py +17 -24
- ultralytics/models/sam/modules/encoders.py +40 -56
- ultralytics/models/sam/modules/memory_attention.py +10 -16
- ultralytics/models/sam/modules/sam.py +41 -47
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +17 -27
- ultralytics/models/sam/modules/utils.py +31 -42
- ultralytics/models/sam/predict.py +172 -209
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/classify/predict.py +8 -11
- ultralytics/models/yolo/classify/train.py +8 -16
- ultralytics/models/yolo/classify/val.py +13 -20
- ultralytics/models/yolo/detect/predict.py +4 -8
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +38 -48
- ultralytics/models/yolo/model.py +35 -47
- ultralytics/models/yolo/obb/predict.py +5 -8
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +20 -28
- ultralytics/models/yolo/pose/predict.py +5 -8
- ultralytics/models/yolo/pose/train.py +4 -8
- ultralytics/models/yolo/pose/val.py +31 -39
- ultralytics/models/yolo/segment/predict.py +9 -14
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +16 -26
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -16
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/autobackend.py +10 -18
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +99 -185
- ultralytics/nn/modules/conv.py +45 -90
- ultralytics/nn/modules/head.py +44 -98
- ultralytics/nn/modules/transformer.py +44 -76
- ultralytics/nn/modules/utils.py +14 -19
- ultralytics/nn/tasks.py +86 -146
- ultralytics/nn/text_model.py +25 -40
- ultralytics/solutions/ai_gym.py +10 -16
- ultralytics/solutions/analytics.py +7 -10
- ultralytics/solutions/config.py +4 -5
- ultralytics/solutions/distance_calculation.py +9 -12
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +8 -12
- ultralytics/solutions/object_cropper.py +5 -8
- ultralytics/solutions/parking_management.py +12 -14
- ultralytics/solutions/queue_management.py +4 -6
- ultralytics/solutions/region_counter.py +7 -10
- ultralytics/solutions/security_alarm.py +14 -19
- ultralytics/solutions/similarity_search.py +7 -12
- ultralytics/solutions/solutions.py +31 -53
- ultralytics/solutions/speed_estimation.py +6 -9
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/basetrack.py +2 -4
- ultralytics/trackers/bot_sort.py +6 -11
- ultralytics/trackers/byte_tracker.py +10 -15
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +6 -12
- ultralytics/trackers/utils/kalman_filter.py +35 -43
- ultralytics/trackers/utils/matching.py +6 -10
- ultralytics/utils/__init__.py +61 -100
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +11 -13
- ultralytics/utils/benchmarks.py +25 -35
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +2 -4
- ultralytics/utils/callbacks/comet.py +30 -44
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +4 -6
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +4 -6
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +29 -56
- ultralytics/utils/cpu.py +1 -2
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +17 -27
- ultralytics/utils/errors.py +6 -8
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -239
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +11 -17
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +10 -15
- ultralytics/utils/git.py +5 -7
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +11 -15
- ultralytics/utils/loss.py +8 -14
- ultralytics/utils/metrics.py +98 -138
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +47 -74
- ultralytics/utils/patches.py +11 -18
- ultralytics/utils/plotting.py +29 -42
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +45 -73
- ultralytics/utils/tqdm.py +6 -8
- ultralytics/utils/triton.py +9 -12
- ultralytics/utils/tuner.py +1 -2
- dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
|
@@ -11,12 +11,10 @@ from ultralytics.nn.modules import MLPBlock
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class TwoWayTransformer(nn.Module):
|
|
14
|
-
"""
|
|
15
|
-
A Two-Way Transformer module for simultaneous attention to image and query points.
|
|
14
|
+
"""A Two-Way Transformer module for simultaneous attention to image and query points.
|
|
16
15
|
|
|
17
|
-
This class implements a specialized transformer decoder that attends to an input image using queries with
|
|
18
|
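-
supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point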
|
19
|
-
cloud processing.
|
|
16
|
+
This class implements a specialized transformer decoder that attends to an input image using queries with supplied
|
|
17
|
+
positional embeddings. It's useful for tasks like object detection, image segmentation, and point cloud processing.
|
|
20
18
|
|
|
21
19
|
Attributes:
|
|
22
20
|
depth (int): Number of layers in the transformer.
|
|
@@ -48,8 +46,7 @@ class TwoWayTransformer(nn.Module):
|
|
|
48
46
|
activation: type[nn.Module] = nn.ReLU,
|
|
49
47
|
attention_downsample_rate: int = 2,
|
|
50
48
|
) -> None:
|
|
51
|
-
"""
|
|
52
|
-
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
|
|
49
|
+
"""Initialize a Two-Way Transformer for simultaneous attention to image and query points.
|
|
53
50
|
|
|
54
51
|
Args:
|
|
55
52
|
depth (int): Number of layers in the transformer.
|
|
@@ -87,8 +84,7 @@ class TwoWayTransformer(nn.Module):
|
|
|
87
84
|
image_pe: torch.Tensor,
|
|
88
85
|
point_embedding: torch.Tensor,
|
|
89
86
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
90
|
-
"""
|
|
91
|
-
Process image and point embeddings through the Two-Way Transformer.
|
|
87
|
+
"""Process image and point embeddings through the Two-Way Transformer.
|
|
92
88
|
|
|
93
89
|
Args:
|
|
94
90
|
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
|
|
@@ -127,12 +123,11 @@ class TwoWayTransformer(nn.Module):
|
|
|
127
123
|
|
|
128
124
|
|
|
129
125
|
class TwoWayAttentionBlock(nn.Module):
|
|
130
|
-
"""
|
|
131
|
-
A two-way attention block for simultaneous attention to image and query points.
|
|
126
|
+
"""A two-way attention block for simultaneous attention to image and query points.
|
|
132
127
|
|
|
133
128
|
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
|
|
134
|
-
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
|
|
135
|
-
inputs to sparse inputs.
|
129
|
+
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense inputs to
|
|
130
|
+
sparse inputs.
|
|
136
131
|
|
|
137
132
|
Attributes:
|
|
138
133
|
self_attn (Attention): Self-attention layer for queries.
|
|
@@ -167,12 +162,11 @@ class TwoWayAttentionBlock(nn.Module):
|
|
|
167
162
|
attention_downsample_rate: int = 2,
|
|
168
163
|
skip_first_layer_pe: bool = False,
|
|
169
164
|
) -> None:
|
|
170
|
-
"""
|
|
171
|
-
Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
|
|
165
|
+
"""Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
|
|
172
166
|
|
|
173
167
|
This block implements a specialized transformer layer with four main components: self-attention on sparse
|
|
174
|
-
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
|
|
175
|
-
of dense inputs to sparse inputs.
|
168
|
+
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of
|
|
169
|
+
dense inputs to sparse inputs.
|
|
176
170
|
|
|
177
171
|
Args:
|
|
178
172
|
embedding_dim (int): Channel dimension of the embeddings.
|
|
@@ -200,8 +194,7 @@ class TwoWayAttentionBlock(nn.Module):
|
|
|
200
194
|
def forward(
|
|
201
195
|
self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
|
|
202
196
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
203
|
-
"""
|
|
204
|
-
Apply two-way attention to process query and key embeddings in a transformer block.
|
|
197
|
+
"""Apply two-way attention to process query and key embeddings in a transformer block.
|
|
205
198
|
|
|
206
199
|
Args:
|
|
207
200
|
queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
|
|
@@ -245,11 +238,10 @@ class TwoWayAttentionBlock(nn.Module):
|
|
|
245
238
|
|
|
246
239
|
|
|
247
240
|
class Attention(nn.Module):
|
|
248
|
-
"""
|
|
249
|
-
An attention layer with downscaling capability for embedding size after projection.
|
|
241
|
+
"""An attention layer with downscaling capability for embedding size after projection.
|
|
250
242
|
|
|
251
|
-
This class implements a multi-head attention mechanism with the option to downsample the internal
|
|
252
|
-
dimension of queries, keys, and values.
|
243
|
+
This class implements a multi-head attention mechanism with the option to downsample the internal dimension of
|
|
244
|
+
queries, keys, and values.
|
|
253
245
|
|
|
254
246
|
Attributes:
|
|
255
247
|
embedding_dim (int): Dimensionality of input embeddings.
|
|
@@ -282,8 +274,7 @@ class Attention(nn.Module):
|
|
|
282
274
|
downsample_rate: int = 1,
|
|
283
275
|
kv_in_dim: int | None = None,
|
|
284
276
|
) -> None:
|
|
285
|
-
"""
|
|
286
|
-
Initialize the Attention module with specified dimensions and settings.
|
|
277
|
+
"""Initialize the Attention module with specified dimensions and settings.
|
|
287
278
|
|
|
288
279
|
Args:
|
|
289
280
|
embedding_dim (int): Dimensionality of input embeddings.
|
|
@@ -321,8 +312,7 @@ class Attention(nn.Module):
|
|
|
321
312
|
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
|
322
313
|
|
|
323
314
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
|
324
|
-
"""
|
|
325
|
-
Apply multi-head attention to query, key, and value tensors with optional downsampling.
|
|
315
|
+
"""Apply multi-head attention to query, key, and value tensors with optional downsampling.
|
|
326
316
|
|
|
327
317
|
Args:
|
|
328
318
|
q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).
|
|
@@ -9,8 +9,7 @@ import torch.nn.functional as F
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
|
|
12
|
-
"""
|
|
13
|
-
Select the closest conditioning frames to a given frame index.
|
|
12
|
+
"""Select the closest conditioning frames to a given frame index.
|
|
14
13
|
|
|
15
14
|
Args:
|
|
16
15
|
frame_idx (int): Current frame index.
|
|
@@ -62,8 +61,7 @@ def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any
|
|
|
62
61
|
|
|
63
62
|
|
|
64
63
|
def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
|
|
65
|
-
"""
|
|
66
|
-
Generate 1D sinusoidal positional embeddings for given positions and dimensions.
|
|
64
|
+
"""Generate 1D sinusoidal positional embeddings for given positions and dimensions.
|
|
67
65
|
|
|
68
66
|
Args:
|
|
69
67
|
pos_inds (torch.Tensor): Position indices for which to generate embeddings.
|
|
@@ -89,11 +87,10 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
|
|
|
89
87
|
|
|
90
88
|
|
|
91
89
|
def init_t_xy(end_x: int, end_y: int):
|
|
92
|
-
"""
|
|
93
|
-
Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
|
|
90
|
+
"""Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
|
|
94
91
|
|
|
95
|
-
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
|
|
96
|
-
tensor and corresponding x and y coordinate tensors.
|
|
92
|
+
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
|
|
93
|
+
tensor and corresponding x and y coordinate tensors.
|
|
97
94
|
|
|
98
95
|
Args:
|
|
99
96
|
end_x (int): Width of the grid (number of columns).
|
|
@@ -117,11 +114,10 @@ def init_t_xy(end_x: int, end_y: int):
|
|
|
117
114
|
|
|
118
115
|
|
|
119
116
|
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|
120
|
-
"""
|
|
121
|
-
Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
|
|
117
|
+
"""Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
|
|
122
118
|
|
|
123
|
-
This function generates complex exponential positional encodings for a 2D grid of spatial positions,
|
|
124
|
-
using separate frequency components for the x and y dimensions.
|
119
|
+
This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
|
|
120
|
+
frequency components for the x and y dimensions.
|
|
125
121
|
|
|
126
122
|
Args:
|
|
127
123
|
dim (int): Dimension of the positional encoding.
|
|
@@ -150,11 +146,10 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|
|
150
146
|
|
|
151
147
|
|
|
152
148
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
|
153
|
-
"""
|
|
154
|
-
Reshape frequency tensor for broadcasting with input tensor.
|
|
149
|
+
"""Reshape frequency tensor for broadcasting with input tensor.
|
|
155
150
|
|
|
156
|
-
Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
|
|
157
|
-
This function is typically used in positional encoding operations.
|
151
|
+
Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor. This function
|
|
152
|
+
is typically used in positional encoding operations.
|
|
158
153
|
|
|
159
154
|
Args:
|
|
160
155
|
freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
|
|
@@ -179,8 +174,7 @@ def apply_rotary_enc(
|
|
|
179
174
|
freqs_cis: torch.Tensor,
|
|
180
175
|
repeat_freqs_k: bool = False,
|
|
181
176
|
):
|
|
182
|
-
"""
|
|
183
|
-
Apply rotary positional encoding to query and key tensors.
|
|
177
|
+
"""Apply rotary positional encoding to query and key tensors.
|
|
184
178
|
|
|
185
179
|
This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
|
|
186
180
|
components. RoPE is a technique that injects relative position information into self-attention mechanisms.
|
|
@@ -188,10 +182,10 @@ def apply_rotary_enc(
|
|
|
188
182
|
Args:
|
|
189
183
|
xq (torch.Tensor): Query tensor to encode with positional information.
|
|
190
184
|
xk (torch.Tensor): Key tensor to encode with positional information.
|
|
191
|
-
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
|
|
192
|
-
last two dimensions of xq.
|
193
|
-
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
|
|
194
|
-
to match key sequence length.
|
185
|
+
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the last
|
|
186
|
+
two dimensions of xq.
|
|
187
|
+
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension to match
|
|
188
|
+
key sequence length.
|
|
195
189
|
|
|
196
190
|
Returns:
|
|
197
191
|
xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
|
|
@@ -220,8 +214,7 @@ def apply_rotary_enc(
|
|
|
220
214
|
|
|
221
215
|
|
|
222
216
|
def window_partition(x: torch.Tensor, window_size: int):
|
|
223
|
-
"""
|
|
224
|
-
Partition input tensor into non-overlapping windows with padding if needed.
|
|
217
|
+
"""Partition input tensor into non-overlapping windows with padding if needed.
|
|
225
218
|
|
|
226
219
|
Args:
|
|
227
220
|
x (torch.Tensor): Input tensor with shape (B, H, W, C).
|
|
@@ -251,23 +244,22 @@ def window_partition(x: torch.Tensor, window_size: int):
|
|
|
251
244
|
|
|
252
245
|
|
|
253
246
|
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
|
|
254
|
-
"""
|
|
255
|
-
Unpartition windowed sequences into original sequences and remove padding.
|
|
247
|
+
"""Unpartition windowed sequences into original sequences and remove padding.
|
|
256
248
|
|
|
257
|
-
This function reverses the windowing process, reconstructing the original input from windowed segments
|
|
258
|
-
and removing any padding that was added during the windowing process.
|
249
|
+
This function reverses the windowing process, reconstructing the original input from windowed segments and removing
|
|
250
|
+
any padding that was added during the windowing process.
|
|
259
251
|
|
|
260
252
|
Args:
|
|
261
253
|
windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
|
|
262
|
-
window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
|
|
263
|
-
the size of each window, and C is the number of channels.
|
254
|
+
window_size, C), where B is the batch size, num_windows is the number of windows, window_size is the size of
|
|
255
|
+
each window, and C is the number of channels.
|
|
264
256
|
window_size (int): Size of each window.
|
|
265
257
|
pad_hw (tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
|
|
266
258
|
hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
|
|
267
259
|
|
|
268
260
|
Returns:
|
|
269
|
-
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
|
|
270
|
-
are the original height and width, and C is the number of channels.
|
261
|
+
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W are the
|
|
262
|
+
original height and width, and C is the number of channels.
|
|
271
263
|
|
|
272
264
|
Examples:
|
|
273
265
|
>>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
|
|
@@ -289,18 +281,16 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
|
|
|
289
281
|
|
|
290
282
|
|
|
291
283
|
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
|
292
|
-
"""
|
|
293
|
-
Extract relative positional embeddings based on query and key sizes.
|
|
284
|
+
"""Extract relative positional embeddings based on query and key sizes.
|
|
294
285
|
|
|
295
286
|
Args:
|
|
296
287
|
q_size (int): Size of the query.
|
|
297
288
|
k_size (int): Size of the key.
|
|
298
|
-
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
|
|
299
|
-
distance and C is the embedding dimension.
|
289
|
+
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative distance
|
|
290
|
+
and C is the embedding dimension.
|
|
300
291
|
|
|
301
292
|
Returns:
|
|
302
|
-
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
|
|
303
|
-
k_size, C).
|
|
293
|
+
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, k_size, C).
|
|
304
294
|
|
|
305
295
|
Examples:
|
|
306
296
|
>>> q_size, k_size = 8, 16
|
|
@@ -338,8 +328,7 @@ def add_decomposed_rel_pos(
|
|
|
338
328
|
q_size: tuple[int, int],
|
|
339
329
|
k_size: tuple[int, int],
|
|
340
330
|
) -> torch.Tensor:
|
|
341
|
-
"""
|
|
342
|
-
Add decomposed Relative Positional Embeddings to the attention map.
|
|
331
|
+
"""Add decomposed Relative Positional Embeddings to the attention map.
|
|
343
332
|
|
|
344
333
|
This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
|
|
345
334
|
paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
|
|
@@ -354,8 +343,8 @@ def add_decomposed_rel_pos(
|
|
|
354
343
|
k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
|
|
355
344
|
|
|
356
345
|
Returns:
|
|
357
|
-
(torch.Tensor): Updated attention map with added relative positional embeddings, shape
|
|
358
|
-
(B, q_h * q_w, k_h * k_w).
|
346
|
+
(torch.Tensor): Updated attention map with added relative positional embeddings, shape (B, q_h * q_w, k_h *
|
|
347
|
+
k_w).
|
|
359
348
|
|
|
360
349
|
Examples:
|
|
361
350
|
>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
|