ultralytics-8.2.68-py3-none-any.whl → ultralytics-8.2.70-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ultralytics might be problematic.

Files changed (37)
  1. tests/test_cli.py +4 -16
  2. ultralytics/__init__.py +3 -2
  3. ultralytics/cfg/__init__.py +4 -0
  4. ultralytics/data/augment.py +1 -1
  5. ultralytics/hub/google/__init__.py +3 -3
  6. ultralytics/models/__init__.py +2 -1
  7. ultralytics/models/fastsam/__init__.py +1 -2
  8. ultralytics/models/fastsam/model.py +18 -0
  9. ultralytics/models/fastsam/predict.py +116 -1
  10. ultralytics/models/sam/build.py +2 -2
  11. ultralytics/models/sam/model.py +10 -2
  12. ultralytics/models/sam/modules/decoders.py +1 -42
  13. ultralytics/models/sam/modules/encoders.py +3 -1
  14. ultralytics/models/sam/modules/sam.py +5 -7
  15. ultralytics/models/sam/modules/transformer.py +4 -3
  16. ultralytics/models/sam/predict.py +12 -6
  17. ultralytics/models/sam2/__init__.py +6 -0
  18. ultralytics/models/sam2/build.py +156 -0
  19. ultralytics/models/sam2/model.py +97 -0
  20. ultralytics/models/sam2/modules/__init__.py +1 -0
  21. ultralytics/models/sam2/modules/decoders.py +305 -0
  22. ultralytics/models/sam2/modules/encoders.py +332 -0
  23. ultralytics/models/sam2/modules/memory_attention.py +170 -0
  24. ultralytics/models/sam2/modules/sam2.py +804 -0
  25. ultralytics/models/sam2/modules/sam2_blocks.py +715 -0
  26. ultralytics/models/sam2/modules/utils.py +191 -0
  27. ultralytics/models/sam2/predict.py +182 -0
  28. ultralytics/nn/modules/transformer.py +5 -3
  29. ultralytics/utils/ops.py +1 -1
  30. ultralytics/utils/torch_utils.py +9 -6
  31. {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/METADATA +1 -1
  32. {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/RECORD +36 -26
  33. {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/WHEEL +1 -1
  34. ultralytics/models/fastsam/prompt.py +0 -352
  35. {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/LICENSE +0 -0
  36. {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/entry_points.txt +0 -0
  37. {ultralytics-8.2.68.dist-info → ultralytics-8.2.70.dist-info}/top_level.txt +0 -0
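
The headline change in this release is the new ultralytics/models/sam2 package (build.py, model.py, predict.py and the modules listed above). Below is a minimal, hedged usage sketch; the SAM2 import path, the "sam2_b.pt" weight name, and the prompt keywords are assumptions inferred from the file list and the existing SAM interface, not details confirmed by this diff.

    # Hedged sketch only: assumes SAM2 is exported from ultralytics.models.sam2
    # and that a "sam2_b.pt" checkpoint can be resolved; neither is confirmed here.
    from ultralytics.models.sam2 import SAM2

    model = SAM2("sam2_b.pt")  # hypothetical SAM2 checkpoint name
    model.info()  # print a model summary
    # Prompted inference, mirroring the existing SAM predictor keywords (assumed)
    results = model("path/to/image.jpg", bboxes=[100, 100, 400, 400])
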
ultralytics/models/sam2/modules/sam2_blocks.py (new file)
@@ -0,0 +1,715 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import copy
+import math
+from functools import partial
+from typing import Optional, Tuple, Type, Union
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ultralytics.models.sam.modules.transformer import (
+    Attention,
+)
+from ultralytics.models.sam.modules.transformer import (
+    TwoWayAttentionBlock as SAMTwoWayAttentionBlock,
+)
+from ultralytics.models.sam.modules.transformer import (
+    TwoWayTransformer as SAMTwoWayTransformer,
+)
+from ultralytics.nn.modules import MLP, LayerNorm2d
+
+from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
+
+
+class DropPath(nn.Module):
+    """Implements stochastic depth regularization for neural networks during training."""
+
+    def __init__(self, drop_prob=0.0, scale_by_keep=True):
+        """Initialize DropPath module with specified drop probability and scaling option."""
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        """Applies stochastic depth to input tensor during training, with optional scaling."""
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+
+class MaskDownSampler(nn.Module):
+    """Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing."""
+
+    def __init__(
+        self,
+        embed_dim=256,
+        kernel_size=4,
+        stride=4,
+        padding=0,
+        total_stride=16,
+        activation=nn.GELU,
+    ):
+        """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
+        super().__init__()
+        num_layers = int(math.log2(total_stride) // math.log2(stride))
+        assert stride**num_layers == total_stride
+        self.encoder = nn.Sequential()
+        mask_in_chans, mask_out_chans = 1, 1
+        for _ in range(num_layers):
+            mask_out_chans = mask_in_chans * (stride**2)
+            self.encoder.append(
+                nn.Conv2d(
+                    mask_in_chans,
+                    mask_out_chans,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                )
+            )
+            self.encoder.append(LayerNorm2d(mask_out_chans))
+            self.encoder.append(activation())
+            mask_in_chans = mask_out_chans
+
+        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+    def forward(self, x):
+        """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+        return self.encoder(x)
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class CXBlock(nn.Module):
+    """
+    ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+
+    This block implements a modified version of the ConvNeXt architecture, offering two equivalent
+    implementations for improved performance and flexibility.
+
+    Attributes:
+        dwconv (nn.Conv2d): Depthwise convolution layer.
+        norm (LayerNorm2d): Layer normalization applied to channels.
+        pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
+        act (nn.GELU): GELU activation function.
+        pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
+        gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
+        drop_path (nn.Module): DropPath layer for stochastic depth regularization.
+
+    Methods:
+        forward: Processes the input tensor through the ConvNeXt block.
+
+    Examples:
+        >>> import torch
+        >>> x = torch.randn(1, 64, 56, 56)
+        >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+        >>> output = block(x)
+        >>> print(output.shape)
+        torch.Size([1, 64, 56, 56])
+    """
+
+    def __init__(
+        self,
+        dim,
+        kernel_size=7,
+        padding=3,
+        drop_path=0.0,
+        layer_scale_init_value=1e-6,
+        use_dwconv=True,
+    ):
+        """
+        Initialize a ConvNeXt Block.
+
+        This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization,
+        pointwise convolutions, and GELU activation.
+
+        Args:
+            dim (int): Number of input channels.
+            kernel_size (int): Size of the convolutional kernel. Default is 7.
+            padding (int): Padding size for the convolution. Default is 3.
+            drop_path (float): Stochastic depth rate. Default is 0.0.
+            layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6.
+            use_dwconv (bool): Whether to use depthwise convolution. Default is True.
+
+        Attributes:
+            dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
+            norm (LayerNorm2d): Layer normalization applied to the output of dwconv.
+            pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
+            act (nn.GELU): GELU activation function.
+            pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
+            gamma (nn.Parameter | None): Learnable scale parameter for the residual path.
+
+        Examples:
+            >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+            >>> x = torch.randn(1, 64, 32, 32)
+            >>> output = block(x)
+            >>> print(output.shape)
+            torch.Size([1, 64, 32, 32])
+        """
+        super().__init__()
+        self.dwconv = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=dim if use_dwconv else 1,
+        )  # depthwise conv
+        self.norm = LayerNorm2d(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
+        input = x
+        x = self.dwconv(x)
+        x = self.norm(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class Fuser(nn.Module):
+    """
+    A module for fusing features through multiple layers of a neural network.
+
+    This class applies a series of identical layers to an input tensor, optionally projecting the input first.
+
+    Attributes:
+        proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
+        layers (nn.ModuleList): A list of identical layers to be applied sequentially.
+
+    Methods:
+        forward: Applies the fuser to an input tensor.
+
+    Examples:
+        >>> layer = CXBlock(dim=256)
+        >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
+        >>> x = torch.randn(1, 256, 32, 32)
+        >>> output = fuser(x)
+        >>> print(output.shape)
+        torch.Size([1, 256, 32, 32])
+    """
+
+    def __init__(self, layer, num_layers, dim=None, input_projection=False):
+        """
+        Initializes the Fuser module.
+
+        This module creates a sequence of identical layers and optionally applies an input projection.
+
+        Args:
+            layer (nn.Module): The layer to be replicated in the fuser.
+            num_layers (int): The number of times to replicate the layer.
+            dim (int | None): The dimension for input projection, if used.
+            input_projection (bool): Whether to use input projection.
+
+        Attributes:
+            proj (nn.Module): The input projection layer, or nn.Identity if not used.
+            layers (nn.ModuleList): A list of replicated layers.
+
+        Examples:
+            >>> layer = nn.Linear(64, 64)
+            >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
+            >>> input_tensor = torch.randn(1, 64)
+            >>> output = fuser(input_tensor)
+        """
+        super().__init__()
+        self.proj = nn.Identity()
+        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+
+        if input_projection:
+            assert dim is not None
+            self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+    def forward(self, x):
+        """Applies a series of layers to the input tensor, optionally projecting it first."""
+        x = self.proj(x)
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class TwoWayAttentionBlock(SAMTwoWayAttentionBlock):
+    """
+    A two-way attention block for performing self-attention and cross-attention in both directions.
+
+    This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on
+    sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
+    cross-attention from dense to sparse inputs.
+
+    Attributes:
+        self_attn (Attention): Self-attention layer for queries.
+        norm1 (nn.LayerNorm): Layer normalization after the first attention block.
+        cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+        norm2 (nn.LayerNorm): Layer normalization after the second attention block.
+        mlp (MLP): MLP block for transforming query embeddings.
+        norm3 (nn.LayerNorm): Layer normalization after the MLP block.
+        norm4 (nn.LayerNorm): Layer normalization after the third attention block.
+        cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+        skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
+
+    Methods:
+        forward: Processes input through the attention blocks and MLP.
+
+    Examples:
+        >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
+        >>> sparse_input = torch.randn(1, 100, 256)
+        >>> dense_input = torch.randn(1, 256, 16, 16)
+        >>> sparse_output, dense_output = block(sparse_input, dense_input)
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
+
+        This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs
+        to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs.
+
+        Args:
+            embedding_dim (int): The channel dimension of the embeddings.
+            num_heads (int): The number of heads in the attention layers.
+            mlp_dim (int): The hidden dimension of the MLP block.
+            activation (Type[nn.Module]): The activation function of the MLP block.
+            attention_downsample_rate (int): The downsample rate for attention computations.
+            skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+
+        Attributes:
+            self_attn (Attention): The self-attention layer for the queries.
+            norm1 (nn.LayerNorm): Layer normalization following the first attention block.
+            cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+            norm2 (nn.LayerNorm): Layer normalization following the second attention block.
+            mlp (MLP): MLP block that transforms the query embeddings.
+            norm3 (nn.LayerNorm): Layer normalization following the MLP block.
+            norm4 (nn.LayerNorm): Layer normalization following the third attention block.
+            cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+            skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+
+        Examples:
+            >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> sparse_inputs = torch.randn(1, 100, 256)
+            >>> dense_inputs = torch.randn(1, 256, 32, 32)
+            >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
+        """
+        super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
+        self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
+
+
+class TwoWayTransformer(SAMTwoWayTransformer):
+    """
+    A Two-Way Transformer module for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer decoder that attends to an input image using queries with
+    supplied positional embeddings. It is particularly useful for tasks like object detection, image
+    segmentation, and point cloud processing.
+
+    Attributes:
+        depth (int): Number of layers in the transformer.
+        embedding_dim (int): Channel dimension for input embeddings.
+        num_heads (int): Number of heads for multihead attention.
+        mlp_dim (int): Internal channel dimension for the MLP block.
+        layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
+        final_attn_token_to_image (Attention): Final attention layer from queries to image.
+        norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+    Methods:
+        forward: Processes input image embeddings and query embeddings through the transformer.
+
+    Examples:
+        >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+        >>> image_embedding = torch.randn(1, 256, 64, 64)
+        >>> query_embedding = torch.randn(1, 100, 256)
+        >>> output = transformer(image_embedding, query_embedding)
+    """
+
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        Initializes a TwoWayTransformer instance.
+
+        This transformer decoder attends to an input image using queries with supplied positional embeddings.
+        It is designed for tasks like object detection, image segmentation, and point cloud processing.
+
+        Args:
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for the input embeddings.
+            num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+            mlp_dim (int): Channel dimension internal to the MLP block.
+            activation (Type[nn.Module]): Activation function to use in the MLP block.
+            attention_downsample_rate (int): Downsampling rate for attention computations.
+
+        Attributes:
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for the input embeddings.
+            num_heads (int): Number of heads for multihead attention.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
+            final_attn_token_to_image (Attention): Final attention layer from queries to image.
+            norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> transformer
+            TwoWayTransformer(
+              (layers): ModuleList(
+                (0-4): 5 x TwoWayAttentionBlock(...)
+              )
+              (final_attn_token_to_image): Attention(...)
+              (norm_final_attn): LayerNorm(...)
+            )
+        """
+        super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
+        self.layers = nn.ModuleList()
+        for i in range(depth):
+            self.layers.append(
+                TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+
+class RoPEAttention(Attention):
+    """Implements rotary position encoding for attention mechanisms in transformer architectures."""
+
+    def __init__(
+        self,
+        *args,
+        rope_theta=10000.0,
+        # whether to repeat q rope to match k length
+        # this is needed for cross-attention to memories
+        rope_k_repeat=False,
+        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        **kwargs,
+    ):
+        """Initializes RoPEAttention with rotary position encoding for attention mechanisms."""
+        super().__init__(*args, **kwargs)
+
+        self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
+        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+        self.freqs_cis = freqs_cis
+        self.rope_k_repeat = rope_k_repeat
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
+        """Applies rotary position encoding and computes attention between query, key, and value tensors."""
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Apply rotary position encoding
+        w = h = math.sqrt(q.shape[-2])
+        self.freqs_cis = self.freqs_cis.to(q.device)
+        if self.freqs_cis.shape[0] != q.shape[-2]:
+            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+        if q.shape[-2] != k.shape[-2]:
+            assert self.rope_k_repeat
+
+        num_k_rope = k.size(-2) - num_k_exclude_rope
+        q, k[:, :, :num_k_rope] = apply_rotary_enc(
+            q,
+            k[:, :, :num_k_rope],
+            freqs_cis=self.freqs_cis,
+            repeat_freqs_k=self.rope_k_repeat,
+        )
+
+        # Attention
+        _, _, _, c_per_head = q.shape
+        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+        attn = attn / math.sqrt(c_per_head)
+        attn = torch.softmax(attn, dim=-1)
+
+        # Get output
+        out = attn @ v
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+    """Applies pooling and optional normalization to a tensor, handling permutations for spatial operations."""
+    if pool is None:
+        return x
+    # (B, H, W, C) -> (B, C, H, W)
+    x = x.permute(0, 3, 1, 2)
+    x = pool(x)
+    # (B, C, H', W') -> (B, H', W', C)
+    x = x.permute(0, 2, 3, 1)
+    if norm:
+        x = norm(x)
+
+    return x
+
+
+class MultiScaleAttention(nn.Module):
+    """Implements multi-scale self-attention with optional query pooling for efficient feature extraction."""
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        q_pool: nn.Module = None,
+    ):
+        """Initializes a multi-scale attention module with configurable query pooling and linear projections."""
+        super().__init__()
+
+        self.dim = dim
+        self.dim_out = dim_out
+
+        self.num_heads = num_heads
+        head_dim = dim_out // num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_pool = q_pool
+        self.qkv = nn.Linear(dim, dim_out * 3)
+        self.proj = nn.Linear(dim_out, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies multi-scale attention to input tensor, optionally downsampling query features."""
+        B, H, W, _ = x.shape
+        # qkv with shape (B, H * W, 3, nHead, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+        # q, k, v with shape (B, H * W, nheads, C)
+        q, k, v = torch.unbind(qkv, 2)
+
+        # Q pooling (for downsample at stage changes)
+        if self.q_pool:
+            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+            H, W = q.shape[1:3]  # downsampled shape
+            q = q.reshape(B, H * W, self.num_heads, -1)
+
+        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+        x = F.scaled_dot_product_attention(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+        )
+        # Transpose back
+        x = x.transpose(1, 2)
+        x = x.reshape(B, H, W, -1)
+
+        x = self.proj(x)
+
+        return x
+
+
+class MultiScaleBlock(nn.Module):
+    """Multiscale attention block with window partitioning and query pooling for efficient vision transformers."""
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        drop_path: float = 0.0,
+        norm_layer: Union[nn.Module, str] = "LayerNorm",
+        q_stride: Tuple[int, int] = None,
+        act_layer: nn.Module = nn.GELU,
+        window_size: int = 0,
+    ):
+        """Initializes a multi-scale attention block with optional window partitioning and downsampling."""
+        super().__init__()
+
+        if isinstance(norm_layer, str):
+            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+        self.dim = dim
+        self.dim_out = dim_out
+        self.norm1 = norm_layer(dim)
+
+        self.window_size = window_size
+
+        self.pool, self.q_stride = None, q_stride
+        if self.q_stride:
+            self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
+
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=self.pool,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = MLP(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            dim_out,
+            num_layers=2,
+            act=act_layer,
+        )
+
+        if dim != dim_out:
+            self.proj = nn.Linear(dim, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies multi-scale attention and MLP processing to input tensor, with optional windowing."""
+        shortcut = x  # B, H, W, C
+        x = self.norm1(x)
+
+        # Skip connection
+        if self.dim != self.dim_out:
+            shortcut = do_pool(self.proj(x), self.pool)
+
+        # Window partition
+        window_size = self.window_size
+        if window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, window_size)
+
+        # Window Attention + Q Pooling (if stage change)
+        x = self.attn(x)
+        if self.q_stride:
+            # Shapes have changed due to Q pooling
+            window_size = self.window_size // self.q_stride[0]
+            H, W = shortcut.shape[1:3]
+
+            pad_h = (window_size - H % window_size) % window_size
+            pad_w = (window_size - W % window_size) % window_size
+            pad_hw = (H + pad_h, W + pad_w)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+        x = shortcut + self.drop_path(x)
+        # MLP
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PositionEmbeddingSine(nn.Module):
+    """Generates sinusoidal positional embeddings for 2D inputs like images."""
+
+    def __init__(
+        self,
+        num_pos_feats,
+        temperature: int = 10000,
+        normalize: bool = True,
+        scale: Optional[float] = None,
+    ):
+        """Initializes sinusoidal position embeddings for 2D image inputs."""
+        super().__init__()
+        assert num_pos_feats % 2 == 0, "Expecting even model width"
+        self.num_pos_feats = num_pos_feats // 2
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+        self.cache = {}
+
+    def _encode_xy(self, x, y):
+        """Encodes 2D positions using sine and cosine functions for positional embeddings."""
+        assert len(x) == len(y) and x.ndim == y.ndim == 1
+        x_embed = x * self.scale
+        y_embed = y * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, None] / dim_t
+        pos_y = y_embed[:, None] / dim_t
+        pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
+        pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
+        return pos_x, pos_y
+
+    @torch.no_grad()
+    def encode_boxes(self, x, y, w, h):
+        """Encodes box coordinates and dimensions into positional embeddings for object detection tasks."""
+        pos_x, pos_y = self._encode_xy(x, y)
+        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+        return pos
+
+    encode = encode_boxes  # Backwards compatibility
+
+    @torch.no_grad()
+    def encode_points(self, x, y, labels):
+        """Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels."""
+        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+        assert bx == by and nx == ny and bx == bl and nx == nl
+        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+        return pos
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor):
+        """Generate sinusoidal position embeddings for 2D inputs."""
+        cache_key = (x.shape[-2], x.shape[-1])
+        if cache_key in self.cache:
+            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+        y_embed = (
+            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+            .view(1, -1, 1)
+            .repeat(x.shape[0], 1, x.shape[-1])
+        )
+        x_embed = (
+            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+            .view(1, 1, -1)
+            .repeat(x.shape[0], x.shape[-2], 1)
+        )
+
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        self.cache[cache_key] = pos[0]
+        return pos
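
For readers who want to exercise the new building blocks directly, here is a small, hedged smoke test against a few of the classes defined in the file above. It assumes only that the module is importable at the path shown in the file list (ultralytics/models/sam2/modules/sam2_blocks.py) and that torch is installed; the tensor shapes are chosen to satisfy the constructors and forward signatures visible in the diff, not taken from any external documentation.

    # Hedged smoke test for a few classes from the new sam2_blocks.py shown above.
    import torch
    from ultralytics.models.sam2.modules.sam2_blocks import (
        CXBlock,
        MaskDownSampler,
        MultiScaleBlock,
        PositionEmbeddingSine,
    )

    x = torch.randn(1, 96, 14, 14)  # (B, C, H, W) feature map
    print(CXBlock(dim=96)(x).shape)  # torch.Size([1, 96, 14, 14])

    mask = torch.randn(1, 1, 256, 256)  # raw mask; default total_stride=16
    print(MaskDownSampler(embed_dim=256)(mask).shape)  # torch.Size([1, 256, 16, 16])

    tokens = torch.randn(1, 14, 14, 96)  # channels-last input expected by MultiScaleBlock
    block = MultiScaleBlock(dim=96, dim_out=96, num_heads=4, window_size=7)
    print(block(tokens).shape)  # torch.Size([1, 14, 14, 96])

    pe = PositionEmbeddingSine(num_pos_feats=256)
    print(pe(x).shape)  # torch.Size([1, 256, 14, 14])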