ultralytics 8.2.72__py3-none-any.whl → 8.2.74__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.


This version of ultralytics might be problematic.

Files changed (34)
  1. ultralytics/__init__.py +2 -3
  2. ultralytics/cfg/trackers/botsort.yaml +1 -1
  3. ultralytics/cfg/trackers/bytetrack.yaml +1 -1
  4. ultralytics/models/__init__.py +1 -2
  5. ultralytics/models/sam/__init__.py +2 -2
  6. ultralytics/models/sam/amg.py +27 -21
  7. ultralytics/models/sam/build.py +200 -9
  8. ultralytics/models/sam/model.py +86 -34
  9. ultralytics/models/sam/modules/blocks.py +1131 -0
  10. ultralytics/models/sam/modules/decoders.py +390 -23
  11. ultralytics/models/sam/modules/encoders.py +508 -323
  12. ultralytics/models/{sam2 → sam}/modules/memory_attention.py +73 -6
  13. ultralytics/models/sam/modules/sam.py +887 -16
  14. ultralytics/models/sam/modules/tiny_encoder.py +376 -126
  15. ultralytics/models/sam/modules/transformer.py +155 -54
  16. ultralytics/models/{sam2 → sam}/modules/utils.py +105 -3
  17. ultralytics/models/sam/predict.py +382 -92
  18. ultralytics/trackers/bot_sort.py +2 -3
  19. ultralytics/trackers/byte_tracker.py +2 -3
  20. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/METADATA +44 -44
  21. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/RECORD +25 -33
  22. ultralytics/models/sam2/__init__.py +0 -6
  23. ultralytics/models/sam2/build.py +0 -156
  24. ultralytics/models/sam2/model.py +0 -97
  25. ultralytics/models/sam2/modules/__init__.py +0 -1
  26. ultralytics/models/sam2/modules/decoders.py +0 -305
  27. ultralytics/models/sam2/modules/encoders.py +0 -332
  28. ultralytics/models/sam2/modules/sam2.py +0 -804
  29. ultralytics/models/sam2/modules/sam2_blocks.py +0 -715
  30. ultralytics/models/sam2/predict.py +0 -177
  31. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/LICENSE +0 -0
  32. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/WHEEL +0 -0
  33. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/entry_points.txt +0 -0
  34. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/top_level.txt +0 -0
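The headline change in this release is that the separate ultralytics/models/sam2 package has been removed and folded into ultralytics/models/sam (see the renames and deletions above). The hunks reproduced below correspond to ultralytics/models/sam/modules/encoders.py (item 11, +508 -323). As rough orientation, the sketch below shows how SAM 2 weights would presumably be loaded through the unified SAM interface after this merge; the "sam2_b.pt" weight name and the point-prompt keywords are assumptions about the public Ultralytics API, not something this diff alone confirms.

# Hypothetical usage after the sam2 -> sam merge; the weight name and prompt kwargs are assumptions.
from ultralytics import SAM

model = SAM("sam2_b.pt")  # previously served by the separate ultralytics.models.sam2 package
results = model("path/to/image.jpg", points=[[900, 370]], labels=[1])  # single foreground point prompt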
@@ -1,30 +1,48 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

- from typing import Any, Optional, Tuple, Type
+ from typing import List, Optional, Tuple, Type

- import numpy as np
  import torch
  import torch.nn as nn
  import torch.nn.functional as F

- from ultralytics.nn.modules import LayerNorm2d, MLPBlock
+ from ultralytics.nn.modules import LayerNorm2d
+
+ from .blocks import (
+ Block,
+ CXBlock,
+ Fuser,
+ MaskDownSampler,
+ MultiScaleBlock,
+ PatchEmbed,
+ PositionEmbeddingRandom,
+ PositionEmbeddingSine,
+ )


  class ImageEncoderViT(nn.Module):
  """
- An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
- encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
- The encoded patches are then processed through a neck to generate the final encoded representation.
+ An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.

- This class and its supporting functions below lightly adapted from the ViTDet backbone available at
- https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py.
+ This class processes images by splitting them into patches, applying transformer blocks, and generating a final
+ encoded representation through a neck module.

  Attributes:
  img_size (int): Dimension of input images, assumed to be square.
  patch_embed (PatchEmbed): Module for patch embedding.
- pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
+ pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
  blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
  neck (nn.Sequential): Neck module to further process the output.
+
+ Methods:
+ forward: Processes input through patch embedding, positional embedding, blocks, and neck.
+
+ Examples:
+ >>> import torch
+ >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ >>> input_image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(input_image)
+ >>> print(output.shape)
  """

  def __init__(
@@ -47,22 +65,38 @@ class ImageEncoderViT(nn.Module):
  global_attn_indexes: Tuple[int, ...] = (),
  ) -> None:
  """
+ Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
+
  Args:
- img_size (int): Input image size.
- patch_size (int): Patch size.
+ img_size (int): Input image size, assumed to be square.
+ patch_size (int): Size of image patches.
  in_chans (int): Number of input image channels.
- embed_dim (int): Patch embedding dimension.
- depth (int): Depth of ViT.
- num_heads (int): Number of attention heads in each ViT block.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
- norm_layer (nn.Module): Normalization layer.
- act_layer (nn.Module): Activation layer.
- use_abs_pos (bool): If True, use absolute positional embeddings.
- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
- window_size (int): Window size for window attention blocks.
- global_attn_indexes (list): Indexes for blocks using global attention.
+ embed_dim (int): Dimension of patch embeddings.
+ depth (int): Number of transformer blocks.
+ num_heads (int): Number of attention heads in each block.
+ mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+ out_chans (int): Number of output channels from the neck module.
+ qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
+ norm_layer (Type[nn.Module]): Type of normalization layer to use.
+ act_layer (Type[nn.Module]): Type of activation layer to use.
+ use_abs_pos (bool): If True, uses absolute positional embeddings.
+ use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
+ rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+ window_size (int): Size of attention window for windowed attention blocks.
+ global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
+
+ Attributes:
+ img_size (int): Dimension of input images.
+ patch_embed (PatchEmbed): Module for patch embedding.
+ pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
+ blocks (nn.ModuleList): List of transformer blocks.
+ neck (nn.Sequential): Neck module for final processing.
+
+ Examples:
+ >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ >>> input_image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(input_image)
+ >>> print(output.shape)
  """
  super().__init__()
  self.img_size = img_size
@@ -114,9 +148,7 @@ class ImageEncoderViT(nn.Module):
  )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Processes input through patch embedding, applies positional embedding if present, and passes through blocks
- and neck.
- """
+ """Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
  x = self.patch_embed(x)
  if self.pos_embed is not None:
  x = x + self.pos_embed
@@ -127,8 +159,7 @@ class ImageEncoderViT(nn.Module):

  class PromptEncoder(nn.Module):
  """
- Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
- produces both sparse and dense embeddings for the input prompts.
+ Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.

  Attributes:
  embed_dim (int): Dimension of the embeddings.
@@ -137,10 +168,23 @@ class PromptEncoder(nn.Module):
  pe_layer (PositionEmbeddingRandom): Module for random position embedding.
  num_point_embeddings (int): Number of point embeddings for different types of points.
  point_embeddings (nn.ModuleList): List of point embeddings.
- not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
+ not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
  mask_input_size (Tuple[int, int]): Size of the input mask.
  mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
  no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
+
+ Methods:
+ get_dense_pe: Returns the positional encoding used to encode point prompts.
+ forward: Embeds different types of prompts, returning both sparse and dense embeddings.
+
+ Examples:
+ >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+ >>> boxes = torch.rand(1, 2, 2)
+ >>> masks = torch.rand(1, 1, 256, 256)
+ >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+ >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+ torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  """

  def __init__(
@@ -152,18 +196,37 @@ class PromptEncoder(nn.Module):
  activation: Type[nn.Module] = nn.GELU,
  ) -> None:
  """
- Encodes prompts for input to SAM's mask decoder.
+ Initializes the PromptEncoder module for encoding various types of prompts.
+
+ This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
+ producing both sparse and dense embeddings.

  Args:
- embed_dim (int): The prompts' embedding dimension
- image_embedding_size (tuple(int, int)): The spatial size of the
- image embedding, as (H, W).
- input_image_size (int): The padded size of the image as input
- to the image encoder, as (H, W).
- mask_in_chans (int): The number of hidden channels used for
- encoding input masks.
- activation (nn.Module): The activation to use when encoding
- input masks.
+ embed_dim (int): The dimension of the embeddings.
+ image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
+ input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
+ mask_in_chans (int): The number of hidden channels used for encoding input masks.
+ activation (Type[nn.Module]): The activation function to use when encoding input masks.
+
+ Attributes:
+ embed_dim (int): Dimension of the embeddings.
+ input_image_size (Tuple[int, int]): Size of the input image as (H, W).
+ image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
+ pe_layer (PositionEmbeddingRandom): Module for random position embedding.
+ num_point_embeddings (int): Number of point embeddings for different types of points.
+ point_embeddings (nn.ModuleList): List of point embeddings.
+ not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
+ mask_input_size (Tuple[int, int]): Size of the input mask.
+ mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
+
+ Examples:
+ >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+ >>> boxes = torch.rand(1, 2, 2)
+ >>> masks = torch.rand(1, 1, 256, 256)
+ >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+ >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+ torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  """
  super().__init__()
  self.embed_dim = embed_dim
@@ -190,16 +253,25 @@ class PromptEncoder(nn.Module):

  def get_dense_pe(self) -> torch.Tensor:
  """
- Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
- image encoding.
+ Returns the dense positional encoding used for encoding point prompts.
+
+ This method generates a positional encoding for a dense set of points matching the shape of the image
+ encoding. The encoding is used to provide spatial information to the model when processing point prompts.

  Returns:
- torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
+ (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
+ height and width of the image embedding size, respectively.
+
+ Examples:
+ >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> dense_pe = prompt_encoder.get_dense_pe()
+ >>> print(dense_pe.shape)
+ torch.Size([1, 256, 64, 64])
  """
  return self.pe_layer(self.image_embedding_size).unsqueeze(0)

  def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
- """Embeds point prompts."""
+ """Embeds point prompts by applying positional encoding and label-specific embeddings."""
  points = points + 0.5 # Shift to center of pixel
  if pad:
  padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
@@ -216,7 +288,7 @@ class PromptEncoder(nn.Module):
  return point_embedding

  def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
- """Embeds box prompts."""
+ """Embeds box prompts by applying positional encoding and adding corner embeddings."""
  boxes = boxes + 0.5 # Shift to center of pixel
  coords = boxes.reshape(-1, 2, 2)
  corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
@@ -225,7 +297,7 @@ class PromptEncoder(nn.Module):
  return corner_embedding

  def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
- """Embeds mask inputs."""
+ """Embeds mask inputs by downscaling and processing through convolutional layers."""
  return self.mask_downscaling(masks)

  @staticmethod
@@ -258,14 +330,25 @@ class PromptEncoder(nn.Module):
  Embeds different types of prompts, returning both sparse and dense embeddings.

  Args:
- points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
- boxes (torch.Tensor, None): boxes to embed
- masks (torch.Tensor, None): masks to embed
+ points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
+ tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
+ shape (B, N).
+ boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
+ masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).

  Returns:
- torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
- by the number of input points and boxes.
- torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
+ (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+ - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
+ - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
+
+ Examples:
+ >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+ >>> boxes = torch.rand(1, 2, 2, 2)
+ >>> masks = torch.rand(1, 1, 256, 256)
+ >>> sparse_emb, dense_emb = encoder(points, boxes, masks)
+ >>> print(sparse_emb.shape, dense_emb.shape)
+ torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  """
  bs = self._get_batch_size(points, boxes, masks)
  sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
@@ -287,319 +370,421 @@ class PromptEncoder(nn.Module):
  return sparse_embeddings, dense_embeddings


- class PositionEmbeddingRandom(nn.Module):
- """Positional encoding using random spatial frequencies."""
+ class MemoryEncoder(nn.Module):
+ """
+ Encodes pixel features and masks into a memory representation for efficient image segmentation.

- def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
- """Initializes a position embedding using random spatial frequencies."""
- super().__init__()
- if scale is None or scale <= 0.0:
- scale = 1.0
- self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
-
- # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
- torch.use_deterministic_algorithms(False)
- torch.backends.cudnn.deterministic = False
-
- def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
- """Positionally encode points that are normalized to [0,1]."""
- # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
- coords = 2 * coords - 1
- coords = coords @ self.positional_encoding_gaussian_matrix
- coords = 2 * np.pi * coords
- # Outputs d_1 x ... x d_n x C shape
- return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-
- def forward(self, size: Tuple[int, int]) -> torch.Tensor:
- """Generate positional encoding for a grid of the specified size."""
- h, w = size
- device: Any = self.positional_encoding_gaussian_matrix.device
- grid = torch.ones((h, w), device=device, dtype=torch.float32)
- y_embed = grid.cumsum(dim=0) - 0.5
- x_embed = grid.cumsum(dim=1) - 0.5
- y_embed = y_embed / h
- x_embed = x_embed / w
-
- pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
- return pe.permute(2, 0, 1) # C x H x W
-
- def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
- """Positionally encode points that are not normalized to [0,1]."""
- coords = coords_input.clone()
- coords[:, :, 0] = coords[:, :, 0] / image_size[1]
- coords[:, :, 1] = coords[:, :, 1] / image_size[0]
- return self._pe_encoding(coords.to(torch.float)) # B x N x C
-
-
- class Block(nn.Module):
- """Transformer blocks with support of window attention and residual propagation blocks."""
+ This class processes pixel-level features and masks, fusing them to generate encoded memory representations
+ suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
+
+ Attributes:
+ mask_downsampler (MaskDownSampler): Module for downsampling input masks.
+ pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
+ fuser (Fuser): Module for fusing pixel features and masks.
+ position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
+ out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
+
+ Methods:
+ forward: Processes input pixel features and masks to generate encoded memory representations.
+
+ Examples:
+ >>> import torch
+ >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
+ >>> pix_feat = torch.randn(1, 256, 64, 64)
+ >>> masks = torch.randn(1, 1, 64, 64)
+ >>> encoded_feat, pos = encoder(pix_feat, masks)
+ >>> print(encoded_feat.shape, pos.shape)
+ torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
+ """

  def __init__(
  self,
- dim: int,
- num_heads: int,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- norm_layer: Type[nn.Module] = nn.LayerNorm,
- act_layer: Type[nn.Module] = nn.GELU,
- use_rel_pos: bool = False,
- rel_pos_zero_init: bool = True,
- window_size: int = 0,
- input_size: Optional[Tuple[int, int]] = None,
- ) -> None:
- """
- Args:
- dim (int): Number of input channels.
- num_heads (int): Number of attention heads in each ViT block.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
- norm_layer (nn.Module): Normalization layer.
- act_layer (nn.Module): Activation layer.
- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
- window_size (int): Window size for window attention blocks. If it equals 0, then
- use global attention.
- input_size (tuple(int, int), None): Input resolution for calculating the relative
- positional parameter size.
- """
+ out_dim,
+ in_dim=256, # in_dim of pix_feats
+ ):
+ """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
  super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = Attention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- use_rel_pos=use_rel_pos,
- rel_pos_zero_init=rel_pos_zero_init,
- input_size=input_size if window_size == 0 else (window_size, window_size),
- )

- self.norm2 = norm_layer(dim)
- self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+ self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)

- self.window_size = window_size
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+ self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
+ self.out_proj = nn.Identity()
+ if out_dim != in_dim:
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
- shortcut = x
- x = self.norm1(x)
- # Window partition
- if self.window_size > 0:
- H, W = x.shape[1], x.shape[2]
- x, pad_hw = window_partition(x, self.window_size)
+ def forward(
+ self,
+ pix_feat: torch.Tensor,
+ masks: torch.Tensor,
+ skip_mask_sigmoid: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Processes pixel features and masks to generate encoded memory representations for segmentation."""
+ if not skip_mask_sigmoid:
+ masks = F.sigmoid(masks)
+ masks = self.mask_downsampler(masks)

- x = self.attn(x)
- # Reverse window partition
- if self.window_size > 0:
- x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+ # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
+ pix_feat = pix_feat.to(masks.device)

- x = shortcut + x
- return x + self.mlp(self.norm2(x))
+ x = self.pix_feat_proj(pix_feat)
+ x = x + masks
+ x = self.fuser(x)
+ x = self.out_proj(x)

+ pos = self.position_encoding(x).to(x.dtype)

- class Attention(nn.Module):
- """Multi-head Attention block with relative position embeddings."""
+ return {"vision_features": x, "vision_pos_enc": [pos]}

- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = True,
- use_rel_pos: bool = False,
- rel_pos_zero_init: bool = True,
- input_size: Optional[Tuple[int, int]] = None,
- ) -> None:
- """
- Initialize Attention module.

- Args:
- dim (int): Number of input channels.
- num_heads (int): Number of attention heads.
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
- input_size (tuple(int, int), None): Input resolution for calculating the relative
- positional parameter size.
- """
- super().__init__()
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = head_dim**-0.5
+ class ImageEncoder(nn.Module):
+ """
+ Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.

- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.proj = nn.Linear(dim, dim)
+ This class combines a trunk network for feature extraction with a neck network for feature refinement
+ and positional encoding generation. It can optionally discard the lowest resolution features.

- self.use_rel_pos = use_rel_pos
- if self.use_rel_pos:
- assert input_size is not None, "Input size must be provided if using relative positional encoding."
- # Initialize relative positional embeddings
- self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
- self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+ Attributes:
+ trunk (nn.Module): The trunk network for initial feature extraction.
+ neck (nn.Module): The neck network for feature refinement and positional encoding generation.
+ scalp (int): Number of lowest resolution feature levels to discard.
+
+ Methods:
+ forward: Processes the input image through the trunk and neck networks.
+
+ Examples:
+ >>> trunk = SomeTrunkNetwork()
+ >>> neck = SomeNeckNetwork()
+ >>> encoder = ImageEncoder(trunk, neck, scalp=1)
+ >>> image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(image)
+ >>> print(output.keys())
+ dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
+ """

- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
- B, H, W, _ = x.shape
- # qkv with shape (3, B, nHead, H * W, C)
- qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- # q, k, v with shape (B * nHead, H * W, C)
- q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+ def __init__(
+ self,
+ trunk: nn.Module,
+ neck: nn.Module,
+ scalp: int = 0,
+ ):
+ """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
+ super().__init__()
+ self.trunk = trunk
+ self.neck = neck
+ self.scalp = scalp
+ assert (
+ self.trunk.channel_list == self.neck.backbone_channel_list
+ ), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
+
+ def forward(self, sample: torch.Tensor):
+ """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+ features, pos = self.neck(self.trunk(sample))
+ if self.scalp > 0:
+ # Discard the lowest resolution features
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+ src = features[-1]
+ output = {
+ "vision_features": src,
+ "vision_pos_enc": pos,
+ "backbone_fpn": features,
+ }
+ return output
+
+
+ class FpnNeck(nn.Module):
+ """
+ A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.

- attn = (q * self.scale) @ k.transpose(-2, -1)
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+ similar to ViT positional embedding interpolation.

- if self.use_rel_pos:
- attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+ Attributes:
+ position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
+ convs (nn.ModuleList): List of convolutional layers for each backbone level.
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+ fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+ fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
+
+ Methods:
+ forward: Performs forward pass through the FPN neck.
+
+ Examples:
+ >>> backbone_channels = [64, 128, 256, 512]
+ >>> fpn_neck = FpnNeck(256, backbone_channels)
+ >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]
+ >>> outputs, positions = fpn_neck(inputs)
+ >>> print(len(outputs), len(positions))
+ 4 4
+ """

- attn = attn.softmax(dim=-1)
- x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
- return self.proj(x)
+ def __init__(
+ self,
+ d_model: int,
+ backbone_channel_list: List[int],
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ fpn_interp_model: str = "bilinear",
+ fuse_type: str = "sum",
+ fpn_top_down_levels: Optional[List[int]] = None,
+ ):
+ """
+ Initializes a modified Feature Pyramid Network (FPN) neck.

+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+ similar to ViT positional embedding interpolation.

- def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
- """
- Partition into non-overlapping windows with padding if needed.
- Args:
- x (tensor): input tokens with [B, H, W, C].
- window_size (int): window size.
-
- Returns:
- windows: windows after partition with [B * num_windows, window_size, window_size, C].
- (Hp, Wp): padded height and width before partition
- """
- B, H, W, C = x.shape
+ Args:
+ d_model (int): Dimension of the model.
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ kernel_size (int): Kernel size for the convolutional layers.
+ stride (int): Stride for the convolutional layers.
+ padding (int): Padding for the convolutional layers.
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+ fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+ fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
+
+ Examples:
+ >>> backbone_channels = [64, 128, 256, 512]
+ >>> fpn_neck = FpnNeck(256, backbone_channels)
+ >>> print(fpn_neck)
+ """
+ super().__init__()
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
+ self.convs = nn.ModuleList()
+ self.backbone_channel_list = backbone_channel_list
+ for dim in backbone_channel_list:
+ current = nn.Sequential()
+ current.add_module(
+ "conv",
+ nn.Conv2d(
+ in_channels=dim,
+ out_channels=d_model,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ ),
+ )

- pad_h = (window_size - H % window_size) % window_size
- pad_w = (window_size - W % window_size) % window_size
- if pad_h > 0 or pad_w > 0:
- x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
- Hp, Wp = H + pad_h, W + pad_w
+ self.convs.append(current)
+ self.fpn_interp_model = fpn_interp_model
+ assert fuse_type in ["sum", "avg"]
+ self.fuse_type = fuse_type
+
+ # levels to have top-down features in its outputs
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+ # have top-down propagation, while outputs of level 0 and level 1 have only
+ # lateral features from the same backbone level.
+ if fpn_top_down_levels is None:
+ # default is to have top-down features on all levels
+ fpn_top_down_levels = range(len(self.convs))
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+ def forward(self, xs: List[torch.Tensor]):
+ """
+ Performs forward pass through the Feature Pyramid Network (FPN) neck.

- x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows, (Hp, Wp)
+ This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
+ and top-down feature fusion. It generates output feature maps and corresponding positional encodings.

+ Args:
+ xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).

- def window_unpartition(
- windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
- ) -> torch.Tensor:
+ Returns:
+ (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
+ - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
+ (B, d_model, H, W).
+ - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
+
+ Examples:
+ >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
+ >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
+ >>> outputs, positions = fpn_neck(inputs)
+ >>> print(len(outputs), len(positions))
+ 4 4
+ """
+ out = [None] * len(self.convs)
+ pos = [None] * len(self.convs)
+ assert len(xs) == len(self.convs)
+ # fpn forward pass
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+ prev_features = None
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ x = xs[i]
+ lateral_features = self.convs[n - i](x)
+ if i in self.fpn_top_down_levels and prev_features is not None:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode=self.fpn_interp_model,
+ align_corners=(None if self.fpn_interp_model == "nearest" else False),
+ antialias=False,
+ )
+ prev_features = lateral_features + top_down_features
+ if self.fuse_type == "avg":
+ prev_features /= 2
+ else:
+ prev_features = lateral_features
+ x_out = prev_features
+ out[i] = x_out
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+ return out, pos
+
+
+ class Hiera(nn.Module):
  """
- Window unpartition into original sequences and removing padding.
+ Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.

- Args:
- windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
- window_size (int): window size.
- pad_hw (Tuple): padded height and width (Hp, Wp).
- hw (Tuple): original height and width (H, W) before padding.
+ This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
+ efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
+ with optional pooling and global attention mechanisms.

- Returns:
- x: unpartitioned sequences with [B, H, W, C].
+ Attributes:
+ window_spec (Tuple[int, ...]): Window sizes for each stage.
+ q_stride (Tuple[int, int]): Downsampling stride between stages.
+ stage_ends (List[int]): Indices of the last block in each stage.
+ q_pool_blocks (List[int]): Indices of blocks where pooling is applied.
+ return_interm_layers (bool): Whether to return intermediate layer outputs.
+ patch_embed (PatchEmbed): Module for patch embedding.
+ global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
+ window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
+ pos_embed (nn.Parameter): Positional embedding for the background.
+ pos_embed_window (nn.Parameter): Positional embedding for the window.
+ blocks (nn.ModuleList): List of MultiScaleBlock modules.
+ channel_list (List[int]): List of output channel dimensions for each stage.
+
+ Methods:
+ _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
+ forward: Performs the forward pass through the Hiera model.
+
+ Examples:
+ >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
+ >>> input_tensor = torch.randn(1, 3, 224, 224)
+ >>> output_features = model(input_tensor)
+ >>> for feat in output_features:
+ ... print(feat.shape)
  """
- Hp, Wp = pad_hw
- H, W = hw
- B = windows.shape[0] // (Hp * Wp // window_size // window_size)
- x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
-
- if Hp > H or Wp > W:
- x = x[:, :H, :W, :].contiguous()
- return x

+ def __init__(
+ self,
+ embed_dim: int = 96, # initial embed dim
+ num_heads: int = 1, # initial number of heads
+ drop_path_rate: float = 0.0, # stochastic depth
+ q_pool: int = 3, # number of q_pool stages
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
+ head_mul: float = 2.0, # head_mul factor at stage shift
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+ # window size per stage, when not using global att.
+ window_spec: Tuple[int, ...] = (
+ 8,
+ 4,
+ 14,
+ 7,
+ ),
+ # global attn in these blocks
+ global_att_blocks: Tuple[int, ...] = (
+ 12,
+ 16,
+ 20,
+ ),
+ return_interm_layers=True, # return feats from every stage
+ ):
+ """Initializes the Hiera model, configuring its hierarchical vision transformer architecture."""
+ super().__init__()

- def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
- """
- Get relative positional embeddings according to the relative positions of query and key sizes.
+ assert len(stages) == len(window_spec)
+ self.window_spec = window_spec

- Args:
- q_size (int): size of query q.
- k_size (int): size of key k.
- rel_pos (Tensor): relative position embeddings (L, C).
+ depth = sum(stages)
+ self.q_stride = q_stride
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+ self.return_interm_layers = return_interm_layers

- Returns:
- Extracted positional embeddings according to relative positions.
- """
- max_rel_dist = int(2 * max(q_size, k_size) - 1)
- # Interpolate rel pos if needed.
- if rel_pos.shape[0] != max_rel_dist:
- # Interpolate rel pos.
- rel_pos_resized = F.interpolate(
- rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
- size=max_rel_dist,
- mode="linear",
+ self.patch_embed = PatchEmbed(
+ embed_dim=embed_dim,
+ kernel_size=(7, 7),
+ stride=(4, 4),
+ padding=(3, 3),
  )
- rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
- else:
- rel_pos_resized = rel_pos
-
- # Scale the coords with short length if shapes for q and k are different.
- q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
- k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
- relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
-
- return rel_pos_resized[relative_coords.long()]
-
-
- def add_decomposed_rel_pos(
- attn: torch.Tensor,
- q: torch.Tensor,
- rel_pos_h: torch.Tensor,
- rel_pos_w: torch.Tensor,
- q_size: Tuple[int, int],
- k_size: Tuple[int, int],
- ) -> torch.Tensor:
- """
- Calculate decomposed Relative Positional Embeddings from mvitv2 paper at
- https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py.
-
- Args:
- attn (Tensor): attention map.
- q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
- rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
- rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
- q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
- k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
-
- Returns:
- attn (Tensor): attention map with added relative positional embeddings.
- """
- q_h, q_w = q_size
- k_h, k_w = k_size
- Rh = get_rel_pos(q_h, k_h, rel_pos_h)
- Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+ # Which blocks have global att?
+ self.global_att_blocks = global_att_blocks

- B, _, dim = q.shape
- r_q = q.reshape(B, q_h, q_w, dim)
- rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
- rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+ self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
+ self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))

- attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
- B, q_h * q_w, k_h * k_w
- )
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule

- return attn
+ cur_stage = 1
+ self.blocks = nn.ModuleList()

+ for i in range(depth):
+ dim_out = embed_dim
+ # lags by a block, so first block of
+ # next stage uses an initial window size
+ # of previous stage and final window size of current stage
+ window_size = self.window_spec[cur_stage - 1]

- class PatchEmbed(nn.Module):
- """Image to Patch Embedding."""
+ if self.global_att_blocks is not None:
+ window_size = 0 if i in self.global_att_blocks else window_size

- def __init__(
- self,
- kernel_size: Tuple[int, int] = (16, 16),
- stride: Tuple[int, int] = (16, 16),
- padding: Tuple[int, int] = (0, 0),
- in_chans: int = 3,
- embed_dim: int = 768,
- ) -> None:
- """
- Initialize PatchEmbed module.
+ if i - 1 in self.stage_ends:
+ dim_out = int(embed_dim * dim_mul)
+ num_heads = int(num_heads * head_mul)
+ cur_stage += 1

- Args:
- kernel_size (Tuple): kernel size of the projection layer.
- stride (Tuple): stride of the projection layer.
- padding (Tuple): padding size of the projection layer.
- in_chans (int): Number of input image channels.
- embed_dim (int): Patch embedding dimension.
- """
- super().__init__()
+ block = MultiScaleBlock(
+ dim=embed_dim,
+ dim_out=dim_out,
+ num_heads=num_heads,
+ drop_path=dpr[i],
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
+ window_size=window_size,
+ )

- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+ embed_dim = dim_out
+ self.blocks.append(block)

- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Computes patch embedding by applying convolution and transposing resulting tensor."""
- return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
+ self.channel_list = (
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+ if return_interm_layers
+ else [self.blocks[-1].dim_out]
+ )
+
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+ """Generates positional embeddings by interpolating and combining window and background embeddings."""
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+ pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """Performs forward pass through Hiera model, extracting multiscale features from input images."""
+ x = self.patch_embed(x)
+ # x: (B, H, W, C)
+
+ # Add pos embed
+ x = x + self._get_pos_embed(x.shape[1:3])
+
+ outputs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
+ feats = x.permute(0, 3, 1, 2)
+ outputs.append(feats)
+
+ return outputs
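
The helper classes removed above (PositionEmbeddingRandom, Block, Attention, PatchEmbed and the window/relative-position functions) now come from the new blocks.py module, while Hiera, FpnNeck, ImageEncoder and MemoryEncoder are the SAM 2 additions to encoders.py. The following sketch wires the new image-encoder stack together with the default arguments shown in the hunks; it is illustrative only and assumes the classes are importable from ultralytics.models.sam.modules.encoders as the diff suggests. The concrete SAM 2 configurations (embedding widths, channel lists, interpolation modes) are assembled in ultralytics/models/sam/build.py and differ from these defaults.

# Illustrative wiring of the encoder classes introduced above, using their default arguments.
# The real SAM 2 model configurations live in ultralytics/models/sam/build.py.
import torch

from ultralytics.models.sam.modules.encoders import FpnNeck, Hiera, ImageEncoder

trunk = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))         # hierarchical ViT backbone
neck = FpnNeck(d_model=256, backbone_channel_list=trunk.channel_list)  # channel lists must match the trunk
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)                # scalp=1 discards the lowest-resolution level

out = encoder(torch.randn(1, 3, 1024, 1024))
print(out["vision_features"].shape, len(out["backbone_fpn"]))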