ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +527 -67
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +44 -37
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +84 -56
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.28.dist-info/METADATA +0 -373
  244. ultralytics-8.1.28.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,48 @@
- # Ultralytics YOLO 🚀, AGPL-3.0 license
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

- from typing import Any, Optional, Tuple, Type
+ from typing import List, Optional, Tuple, Type

- import numpy as np
  import torch
  import torch.nn as nn
  import torch.nn.functional as F

- from ultralytics.nn.modules import LayerNorm2d, MLPBlock
+ from ultralytics.nn.modules import LayerNorm2d
+
+ from .blocks import (
+     Block,
+     CXBlock,
+     Fuser,
+     MaskDownSampler,
+     MultiScaleBlock,
+     PatchEmbed,
+     PositionEmbeddingRandom,
+     PositionEmbeddingSine,
+ )


  class ImageEncoderViT(nn.Module):
      """
-     An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
-     encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
-     The encoded patches are then processed through a neck to generate the final encoded representation.
+     An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.

-     This class and its supporting functions below lightly adapted from the ViTDet backbone available at
-     https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py.
+     This class processes images by splitting them into patches, applying transformer blocks, and generating a final
+     encoded representation through a neck module.

      Attributes:
          img_size (int): Dimension of input images, assumed to be square.
          patch_embed (PatchEmbed): Module for patch embedding.
-         pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
+         pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
          blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
          neck (nn.Sequential): Neck module to further process the output.
+
+     Methods:
+         forward: Processes input through patch embedding, positional embedding, blocks, and neck.
+
+     Examples:
+         >>> import torch
+         >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+         >>> input_image = torch.randn(1, 3, 224, 224)
+         >>> output = encoder(input_image)
+         >>> print(output.shape)
      """

      def __init__(
@@ -47,22 +65,38 @@ class ImageEncoderViT(nn.Module):
          global_attn_indexes: Tuple[int, ...] = (),
      ) -> None:
          """
+         Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
+
          Args:
-             img_size (int): Input image size.
-             patch_size (int): Patch size.
+             img_size (int): Input image size, assumed to be square.
+             patch_size (int): Size of image patches.
              in_chans (int): Number of input image channels.
-             embed_dim (int): Patch embedding dimension.
-             depth (int): Depth of ViT.
-             num_heads (int): Number of attention heads in each ViT block.
-             mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-             qkv_bias (bool): If True, add a learnable bias to query, key, value.
-             norm_layer (nn.Module): Normalization layer.
-             act_layer (nn.Module): Activation layer.
-             use_abs_pos (bool): If True, use absolute positional embeddings.
-             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
-             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
-             window_size (int): Window size for window attention blocks.
-             global_attn_indexes (list): Indexes for blocks using global attention.
+             embed_dim (int): Dimension of patch embeddings.
+             depth (int): Number of transformer blocks.
+             num_heads (int): Number of attention heads in each block.
+             mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+             out_chans (int): Number of output channels from the neck module.
+             qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
+             norm_layer (Type[nn.Module]): Type of normalization layer to use.
+             act_layer (Type[nn.Module]): Type of activation layer to use.
+             use_abs_pos (bool): If True, uses absolute positional embeddings.
+             use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
+             rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+             window_size (int): Size of attention window for windowed attention blocks.
+             global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
+
+         Attributes:
+             img_size (int): Dimension of input images.
+             patch_embed (PatchEmbed): Module for patch embedding.
+             pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
+             blocks (nn.ModuleList): List of transformer blocks.
+             neck (nn.Sequential): Neck module for final processing.
+
+         Examples:
+             >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+             >>> input_image = torch.randn(1, 3, 224, 224)
+             >>> output = encoder(input_image)
+             >>> print(output.shape)
          """
          super().__init__()
          self.img_size = img_size
@@ -114,12 +148,15 @@ class ImageEncoderViT(nn.Module):
          )

      def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Processes input through patch embedding, applies positional embedding if present, and passes through blocks
-         and neck.
-         """
+         """Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
          x = self.patch_embed(x)
          if self.pos_embed is not None:
-             x = x + self.pos_embed
+             pos_embed = (
+                 F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)
+                 if self.img_size != 1024
+                 else self.pos_embed
+             )
+             x = x + pos_embed
          for blk in self.blocks:
              x = blk(x)
          return self.neck(x.permute(0, 3, 1, 2))
@@ -127,8 +164,7 @@ class ImageEncoderViT(nn.Module):


  class PromptEncoder(nn.Module):
      """
-     Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
-     produces both sparse and dense embeddings for the input prompts.
+     Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.

      Attributes:
          embed_dim (int): Dimension of the embeddings.
@@ -137,10 +173,23 @@ class PromptEncoder(nn.Module):
          pe_layer (PositionEmbeddingRandom): Module for random position embedding.
          num_point_embeddings (int): Number of point embeddings for different types of points.
          point_embeddings (nn.ModuleList): List of point embeddings.
-         not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
+         not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
          mask_input_size (Tuple[int, int]): Size of the input mask.
          mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
          no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
+
+     Methods:
+         get_dense_pe: Returns the positional encoding used to encode point prompts.
+         forward: Embeds different types of prompts, returning both sparse and dense embeddings.
+
+     Examples:
+         >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+         >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+         >>> boxes = torch.rand(1, 2, 2)
+         >>> masks = torch.rand(1, 1, 256, 256)
+         >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+         >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+         torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
      """

      def __init__(
@@ -152,18 +201,37 @@ class PromptEncoder(nn.Module):
          activation: Type[nn.Module] = nn.GELU,
      ) -> None:
          """
-         Encodes prompts for input to SAM's mask decoder.
+         Initializes the PromptEncoder module for encoding various types of prompts.
+
+         This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
+         producing both sparse and dense embeddings.

          Args:
-             embed_dim (int): The prompts' embedding dimension
-             image_embedding_size (tuple(int, int)): The spatial size of the
-                 image embedding, as (H, W).
-             input_image_size (int): The padded size of the image as input
-                 to the image encoder, as (H, W).
-             mask_in_chans (int): The number of hidden channels used for
-                 encoding input masks.
-             activation (nn.Module): The activation to use when encoding
-                 input masks.
+             embed_dim (int): The dimension of the embeddings.
+             image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
+             input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
+             mask_in_chans (int): The number of hidden channels used for encoding input masks.
+             activation (Type[nn.Module]): The activation function to use when encoding input masks.
+
+         Attributes:
+             embed_dim (int): Dimension of the embeddings.
+             input_image_size (Tuple[int, int]): Size of the input image as (H, W).
+             image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
+             pe_layer (PositionEmbeddingRandom): Module for random position embedding.
+             num_point_embeddings (int): Number of point embeddings for different types of points.
+             point_embeddings (nn.ModuleList): List of point embeddings.
+             not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
+             mask_input_size (Tuple[int, int]): Size of the input mask.
+             mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
+
+         Examples:
+             >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+             >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+             >>> boxes = torch.rand(1, 2, 2)
+             >>> masks = torch.rand(1, 1, 256, 256)
+             >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+             >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+             torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
          """
          super().__init__()
          self.embed_dim = embed_dim
@@ -190,16 +258,25 @@ class PromptEncoder(nn.Module):

      def get_dense_pe(self) -> torch.Tensor:
          """
-         Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
-         image encoding.
+         Returns the dense positional encoding used for encoding point prompts.
+
+         This method generates a positional encoding for a dense set of points matching the shape of the image
+         encoding. The encoding is used to provide spatial information to the model when processing point prompts.

          Returns:
-             torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
+             (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
+                 height and width of the image embedding size, respectively.
+
+         Examples:
+             >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+             >>> dense_pe = prompt_encoder.get_dense_pe()
+             >>> print(dense_pe.shape)
+             torch.Size([1, 256, 64, 64])
          """
          return self.pe_layer(self.image_embedding_size).unsqueeze(0)

      def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
-         """Embeds point prompts."""
+         """Embeds point prompts by applying positional encoding and label-specific embeddings."""
          points = points + 0.5  # Shift to center of pixel
          if pad:
              padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
@@ -211,10 +288,12 @@ class PromptEncoder(nn.Module):
          point_embedding[labels == -1] += self.not_a_point_embed.weight
          point_embedding[labels == 0] += self.point_embeddings[0].weight
          point_embedding[labels == 1] += self.point_embeddings[1].weight
+         point_embedding[labels == 2] += self.point_embeddings[2].weight
+         point_embedding[labels == 3] += self.point_embeddings[3].weight
          return point_embedding

      def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
-         """Embeds box prompts."""
+         """Embeds box prompts by applying positional encoding and adding corner embeddings."""
          boxes = boxes + 0.5  # Shift to center of pixel
          coords = boxes.reshape(-1, 2, 2)
          corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
@@ -223,11 +302,11 @@ class PromptEncoder(nn.Module):
          return corner_embedding

      def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
-         """Embeds mask inputs."""
+         """Embeds mask inputs by downscaling and processing through convolutional layers."""
          return self.mask_downscaling(masks)

+     @staticmethod
      def _get_batch_size(
-         self,
          points: Optional[Tuple[torch.Tensor, torch.Tensor]],
          boxes: Optional[torch.Tensor],
          masks: Optional[torch.Tensor],
@@ -256,14 +335,25 @@ class PromptEncoder(nn.Module):
          Embeds different types of prompts, returning both sparse and dense embeddings.

          Args:
-             points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
-             boxes (torch.Tensor, None): boxes to embed
-             masks (torch.Tensor, None): masks to embed
+             points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
+                 tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
+                 shape (B, N).
+             boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
+             masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).

          Returns:
-             torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
-                 by the number of input points and boxes.
-             torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
+             (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+                 - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
+                 - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
+
+         Examples:
+             >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+             >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+             >>> boxes = torch.rand(1, 2, 2, 2)
+             >>> masks = torch.rand(1, 1, 256, 256)
+             >>> sparse_emb, dense_emb = encoder(points, boxes, masks)
+             >>> print(sparse_emb.shape, dense_emb.shape)
+             torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
          """
          bs = self._get_batch_size(points, boxes, masks)
          sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
@@ -285,319 +375,420 @@ class PromptEncoder(nn.Module):
          return sparse_embeddings, dense_embeddings


- class PositionEmbeddingRandom(nn.Module):
-     """Positional encoding using random spatial frequencies."""
+ class MemoryEncoder(nn.Module):
+     """
+     Encodes pixel features and masks into a memory representation for efficient image segmentation.

-     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
-         """Initializes a position embedding using random spatial frequencies."""
-         super().__init__()
-         if scale is None or scale <= 0.0:
-             scale = 1.0
-         self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
-
-         # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
-         torch.use_deterministic_algorithms(False)
-         torch.backends.cudnn.deterministic = False
-
-     def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
-         """Positionally encode points that are normalized to [0,1]."""
-         # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
-         coords = 2 * coords - 1
-         coords = coords @ self.positional_encoding_gaussian_matrix
-         coords = 2 * np.pi * coords
-         # Outputs d_1 x ... x d_n x C shape
-         return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-
-     def forward(self, size: Tuple[int, int]) -> torch.Tensor:
-         """Generate positional encoding for a grid of the specified size."""
-         h, w = size
-         device: Any = self.positional_encoding_gaussian_matrix.device
-         grid = torch.ones((h, w), device=device, dtype=torch.float32)
-         y_embed = grid.cumsum(dim=0) - 0.5
-         x_embed = grid.cumsum(dim=1) - 0.5
-         y_embed = y_embed / h
-         x_embed = x_embed / w
-
-         pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
-         return pe.permute(2, 0, 1) # C x H x W
-
-     def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
-         """Positionally encode points that are not normalized to [0,1]."""
-         coords = coords_input.clone()
-         coords[:, :, 0] = coords[:, :, 0] / image_size[1]
-         coords[:, :, 1] = coords[:, :, 1] / image_size[0]
-         return self._pe_encoding(coords.to(torch.float)) # B x N x C
-
-
- class Block(nn.Module):
-     """Transformer blocks with support of window attention and residual propagation blocks."""
+     This class processes pixel-level features and masks, fusing them to generate encoded memory representations
+     suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
+
+     Attributes:
+         mask_downsampler (MaskDownSampler): Module for downsampling input masks.
+         pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
+         fuser (Fuser): Module for fusing pixel features and masks.
+         position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
+         out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
+
+     Methods:
+         forward: Processes input pixel features and masks to generate encoded memory representations.
+
+     Examples:
+         >>> import torch
+         >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
+         >>> pix_feat = torch.randn(1, 256, 64, 64)
+         >>> masks = torch.randn(1, 1, 64, 64)
+         >>> encoded_feat, pos = encoder(pix_feat, masks)
+         >>> print(encoded_feat.shape, pos.shape)
+         torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
+     """

      def __init__(
          self,
-         dim: int,
-         num_heads: int,
-         mlp_ratio: float = 4.0,
-         qkv_bias: bool = True,
-         norm_layer: Type[nn.Module] = nn.LayerNorm,
-         act_layer: Type[nn.Module] = nn.GELU,
-         use_rel_pos: bool = False,
-         rel_pos_zero_init: bool = True,
-         window_size: int = 0,
-         input_size: Optional[Tuple[int, int]] = None,
-     ) -> None:
-         """
-         Args:
-             dim (int): Number of input channels.
-             num_heads (int): Number of attention heads in each ViT block.
-             mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-             qkv_bias (bool): If True, add a learnable bias to query, key, value.
-             norm_layer (nn.Module): Normalization layer.
-             act_layer (nn.Module): Activation layer.
-             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
-             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
-             window_size (int): Window size for window attention blocks. If it equals 0, then
-                 use global attention.
-             input_size (tuple(int, int), None): Input resolution for calculating the relative
-                 positional parameter size.
-         """
+         out_dim,
+         in_dim=256, # in_dim of pix_feats
+     ):
+         """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
          super().__init__()
-         self.norm1 = norm_layer(dim)
-         self.attn = Attention(
-             dim,
-             num_heads=num_heads,
-             qkv_bias=qkv_bias,
-             use_rel_pos=use_rel_pos,
-             rel_pos_zero_init=rel_pos_zero_init,
-             input_size=input_size if window_size == 0 else (window_size, window_size),
-         )

-         self.norm2 = norm_layer(dim)
-         self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+         self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)

-         self.window_size = window_size
+         self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+         self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
+         self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
+         self.out_proj = nn.Identity()
+         if out_dim != in_dim:
+             self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
-         shortcut = x
-         x = self.norm1(x)
-         # Window partition
-         if self.window_size > 0:
-             H, W = x.shape[1], x.shape[2]
-             x, pad_hw = window_partition(x, self.window_size)
+     def forward(
+         self,
+         pix_feat: torch.Tensor,
+         masks: torch.Tensor,
+         skip_mask_sigmoid: bool = False,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Processes pixel features and masks to generate encoded memory representations for segmentation."""
+         if not skip_mask_sigmoid:
+             masks = F.sigmoid(masks)
+         masks = self.mask_downsampler(masks)
+
+         # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
+         pix_feat = pix_feat.to(masks.device)

-         x = self.attn(x)
-         # Reverse window partition
-         if self.window_size > 0:
-             x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+         x = self.pix_feat_proj(pix_feat)
+         x = x + masks
+         x = self.fuser(x)
+         x = self.out_proj(x)

-         x = shortcut + x
-         return x + self.mlp(self.norm2(x))
+         pos = self.position_encoding(x).to(x.dtype)

+         return {"vision_features": x, "vision_pos_enc": [pos]}

- class Attention(nn.Module):
-     """Multi-head Attention block with relative position embeddings."""
+
+ class ImageEncoder(nn.Module):
+     """
+     Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.
+
+     This class combines a trunk network for feature extraction with a neck network for feature refinement
+     and positional encoding generation. It can optionally discard the lowest resolution features.
+
+     Attributes:
+         trunk (nn.Module): The trunk network for initial feature extraction.
+         neck (nn.Module): The neck network for feature refinement and positional encoding generation.
+         scalp (int): Number of lowest resolution feature levels to discard.
+
+     Methods:
+         forward: Processes the input image through the trunk and neck networks.
+
+     Examples:
+         >>> trunk = SomeTrunkNetwork()
+         >>> neck = SomeNeckNetwork()
+         >>> encoder = ImageEncoder(trunk, neck, scalp=1)
+         >>> image = torch.randn(1, 3, 224, 224)
+         >>> output = encoder(image)
+         >>> print(output.keys())
+         dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
+     """

      def __init__(
          self,
-         dim: int,
-         num_heads: int = 8,
-         qkv_bias: bool = True,
-         use_rel_pos: bool = False,
-         rel_pos_zero_init: bool = True,
-         input_size: Optional[Tuple[int, int]] = None,
-     ) -> None:
-         """
-         Initialize Attention module.
-
-         Args:
-             dim (int): Number of input channels.
-             num_heads (int): Number of attention heads.
-             qkv_bias (bool): If True, add a learnable bias to query, key, value.
-             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
-             input_size (tuple(int, int), None): Input resolution for calculating the relative
-                 positional parameter size.
-         """
+         trunk: nn.Module,
+         neck: nn.Module,
+         scalp: int = 0,
+     ):
+         """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
          super().__init__()
-         self.num_heads = num_heads
-         head_dim = dim // num_heads
-         self.scale = head_dim**-0.5
+         self.trunk = trunk
+         self.neck = neck
+         self.scalp = scalp
+         assert self.trunk.channel_list == self.neck.backbone_channel_list, (
+             f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
+         )

-         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-         self.proj = nn.Linear(dim, dim)
+     def forward(self, sample: torch.Tensor):
+         """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+         features, pos = self.neck(self.trunk(sample))
+         if self.scalp > 0:
+             # Discard the lowest resolution features
+             features, pos = features[: -self.scalp], pos[: -self.scalp]

-         self.use_rel_pos = use_rel_pos
-         if self.use_rel_pos:
-             assert input_size is not None, "Input size must be provided if using relative positional encoding."
-             # Initialize relative positional embeddings
-             self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
-             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+         src = features[-1]
+         return {
+             "vision_features": src,
+             "vision_pos_enc": pos,
+             "backbone_fpn": features,
+         }

-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
-         B, H, W, _ = x.shape
-         # qkv with shape (3, B, nHead, H * W, C)
-         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-         # q, k, v with shape (B * nHead, H * W, C)
-         q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

-         attn = (q * self.scale) @ k.transpose(-2, -1)
502
+ """
503
+ A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
444
504
 
445
- if self.use_rel_pos:
446
- attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
505
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
506
+ similar to ViT positional embedding interpolation.
447
507
 
448
- attn = attn.softmax(dim=-1)
449
- x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
450
- return self.proj(x)
508
+ Attributes:
509
+ position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
510
+ convs (nn.ModuleList): List of convolutional layers for each backbone level.
511
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
512
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
513
+ fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
514
+ fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
515
+
516
+ Methods:
517
+ forward: Performs forward pass through the FPN neck.
518
+
519
+ Examples:
520
+ >>> backbone_channels = [64, 128, 256, 512]
521
+ >>> fpn_neck = FpnNeck(256, backbone_channels)
522
+ >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]
523
+ >>> outputs, positions = fpn_neck(inputs)
524
+ >>> print(len(outputs), len(positions))
525
+ 4 4
526
+ """
451
527
 
528
+ def __init__(
529
+ self,
530
+ d_model: int,
531
+ backbone_channel_list: List[int],
532
+ kernel_size: int = 1,
533
+ stride: int = 1,
534
+ padding: int = 0,
535
+ fpn_interp_model: str = "bilinear",
536
+ fuse_type: str = "sum",
537
+ fpn_top_down_levels: Optional[List[int]] = None,
538
+ ):
539
+ """
540
+ Initializes a modified Feature Pyramid Network (FPN) neck.
452
541
 
453
- def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
454
- """
455
- Partition into non-overlapping windows with padding if needed.
456
- Args:
457
- x (tensor): input tokens with [B, H, W, C].
458
- window_size (int): window size.
459
-
460
- Returns:
461
- windows: windows after partition with [B * num_windows, window_size, window_size, C].
462
- (Hp, Wp): padded height and width before partition
463
- """
464
- B, H, W, C = x.shape
542
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
543
+ similar to ViT positional embedding interpolation.
465
544
 
466
- pad_h = (window_size - H % window_size) % window_size
467
- pad_w = (window_size - W % window_size) % window_size
468
- if pad_h > 0 or pad_w > 0:
469
- x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
470
- Hp, Wp = H + pad_h, W + pad_w
545
+ Args:
546
+ d_model (int): Dimension of the model.
547
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
548
+ kernel_size (int): Kernel size for the convolutional layers.
549
+ stride (int): Stride for the convolutional layers.
550
+ padding (int): Padding for the convolutional layers.
551
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
552
+ fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
553
+ fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
554
+
555
+ Examples:
556
+ >>> backbone_channels = [64, 128, 256, 512]
557
+ >>> fpn_neck = FpnNeck(256, backbone_channels)
558
+ >>> print(fpn_neck)
559
+ """
560
+ super().__init__()
561
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
562
+ self.convs = nn.ModuleList()
563
+ self.backbone_channel_list = backbone_channel_list
564
+ for dim in backbone_channel_list:
565
+ current = nn.Sequential()
566
+ current.add_module(
567
+ "conv",
568
+ nn.Conv2d(
569
+ in_channels=dim,
570
+ out_channels=d_model,
571
+ kernel_size=kernel_size,
572
+ stride=stride,
573
+ padding=padding,
574
+ ),
575
+ )
576
+
577
+ self.convs.append(current)
578
+ self.fpn_interp_model = fpn_interp_model
579
+ assert fuse_type in {"sum", "avg"}
580
+ self.fuse_type = fuse_type
581
+
582
+ # levels to have top-down features in its outputs
583
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
584
+ # have top-down propagation, while outputs of level 0 and level 1 have only
585
+ # lateral features from the same backbone level.
586
+ if fpn_top_down_levels is None:
587
+ # default is to have top-down features on all levels
588
+ fpn_top_down_levels = range(len(self.convs))
589
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
590
+
591
+ def forward(self, xs: List[torch.Tensor]):
592
+ """
593
+ Performs forward pass through the Feature Pyramid Network (FPN) neck.
471
594
 
472
- x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
473
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
474
- return windows, (Hp, Wp)
595
+ This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
596
+ and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
475
597
 
598
+ Args:
599
+ xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
476
600
 
477
- def window_unpartition(
478
- windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
479
- ) -> torch.Tensor:
601
+ Returns:
602
+ (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
603
+ - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
604
+ (B, d_model, H, W).
605
+ - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
606
+
607
+ Examples:
608
+ >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
609
+ >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
610
+ >>> outputs, positions = fpn_neck(inputs)
611
+ >>> print(len(outputs), len(positions))
612
+ 4 4
613
+ """
614
+ out = [None] * len(self.convs)
615
+ pos = [None] * len(self.convs)
616
+ assert len(xs) == len(self.convs)
617
+ # fpn forward pass
618
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
619
+ prev_features = None
620
+ # forward in top-down order (from low to high resolution)
621
+ n = len(self.convs) - 1
622
+ for i in range(n, -1, -1):
623
+ x = xs[i]
624
+ lateral_features = self.convs[n - i](x)
625
+ if i in self.fpn_top_down_levels and prev_features is not None:
626
+ top_down_features = F.interpolate(
627
+ prev_features.to(dtype=torch.float32),
628
+ scale_factor=2.0,
629
+ mode=self.fpn_interp_model,
630
+ align_corners=(None if self.fpn_interp_model == "nearest" else False),
631
+ antialias=False,
632
+ )
633
+ prev_features = lateral_features + top_down_features
634
+ if self.fuse_type == "avg":
635
+ prev_features /= 2
636
+ else:
637
+ prev_features = lateral_features
638
+ x_out = prev_features
639
+ out[i] = x_out
640
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
641
+
642
+ return out, pos
643
+
644
+
645
+ class Hiera(nn.Module):
      """
-     Window unpartition into original sequences and removing padding.
+     Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.

-     Args:
-         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
-         window_size (int): window size.
-         pad_hw (Tuple): padded height and width (Hp, Wp).
-         hw (Tuple): original height and width (H, W) before padding.
+     This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
+     efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
+     with optional pooling and global attention mechanisms.

-     Returns:
-         x: unpartitioned sequences with [B, H, W, C].
+     Attributes:
+         window_spec (Tuple[int, ...]): Window sizes for each stage.
+         q_stride (Tuple[int, int]): Downsampling stride between stages.
+         stage_ends (List[int]): Indices of the last block in each stage.
+         q_pool_blocks (List[int]): Indices of blocks where pooling is applied.
+         return_interm_layers (bool): Whether to return intermediate layer outputs.
+         patch_embed (PatchEmbed): Module for patch embedding.
+         global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
+         window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
+         pos_embed (nn.Parameter): Positional embedding for the background.
+         pos_embed_window (nn.Parameter): Positional embedding for the window.
+         blocks (nn.ModuleList): List of MultiScaleBlock modules.
+         channel_list (List[int]): List of output channel dimensions for each stage.
+
+     Methods:
+         _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
+         forward: Performs the forward pass through the Hiera model.
+
+     Examples:
+         >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
+         >>> input_tensor = torch.randn(1, 3, 224, 224)
+         >>> output_features = model(input_tensor)
+         >>> for feat in output_features:
+         ... print(feat.shape)
      """
-     Hp, Wp = pad_hw
-     H, W = hw
-     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-     x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
-     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
-
-     if Hp > H or Wp > W:
-         x = x[:, :H, :W, :].contiguous()
-     return x

+     def __init__(
+         self,
+         embed_dim: int = 96, # initial embed dim
+         num_heads: int = 1, # initial number of heads
+         drop_path_rate: float = 0.0, # stochastic depth
+         q_pool: int = 3, # number of q_pool stages
+         q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
+         stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
+         dim_mul: float = 2.0, # dim_mul factor at stage shift
+         head_mul: float = 2.0, # head_mul factor at stage shift
+         window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+         # window size per stage, when not using global att.
+         window_spec: Tuple[int, ...] = (
+             8,
+             4,
+             14,
+             7,
+         ),
+         # global attn in these blocks
+         global_att_blocks: Tuple[int, ...] = (
+             12,
+             16,
+             20,
+         ),
+         return_interm_layers=True, # return feats from every stage
+     ):
+         """Initializes the Hiera model, configuring its hierarchical vision transformer architecture."""
+         super().__init__()

- def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
-     """
-     Get relative positional embeddings according to the relative positions of query and key sizes.
+         assert len(stages) == len(window_spec)
+         self.window_spec = window_spec

-     Args:
-         q_size (int): size of query q.
-         k_size (int): size of key k.
-         rel_pos (Tensor): relative position embeddings (L, C).
+         depth = sum(stages)
+         self.q_stride = q_stride
+         self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+         assert 0 <= q_pool <= len(self.stage_ends[:-1])
+         self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+         self.return_interm_layers = return_interm_layers

-     Returns:
-         Extracted positional embeddings according to relative positions.
-     """
-     max_rel_dist = int(2 * max(q_size, k_size) - 1)
-     # Interpolate rel pos if needed.
-     if rel_pos.shape[0] != max_rel_dist:
-         # Interpolate rel pos.
-         rel_pos_resized = F.interpolate(
-             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
-             size=max_rel_dist,
-             mode="linear",
+         self.patch_embed = PatchEmbed(
+             embed_dim=embed_dim,
+             kernel_size=(7, 7),
+             stride=(4, 4),
+             padding=(3, 3),
          )
-         rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
-     else:
-         rel_pos_resized = rel_pos
-
-     # Scale the coords with short length if shapes for q and k are different.
-     q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
-     k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
-     relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
-
-     return rel_pos_resized[relative_coords.long()]
-
-
- def add_decomposed_rel_pos(
-     attn: torch.Tensor,
-     q: torch.Tensor,
-     rel_pos_h: torch.Tensor,
-     rel_pos_w: torch.Tensor,
-     q_size: Tuple[int, int],
-     k_size: Tuple[int, int],
- ) -> torch.Tensor:
-     """
-     Calculate decomposed Relative Positional Embeddings from mvitv2 paper at
-     https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py.
-
-     Args:
-         attn (Tensor): attention map.
-         q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
-         rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
-         rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
-         q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
-         k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
-
-     Returns:
-         attn (Tensor): attention map with added relative positional embeddings.
-     """
-     q_h, q_w = q_size
-     k_h, k_w = k_size
-     Rh = get_rel_pos(q_h, k_h, rel_pos_h)
-     Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+         # Which blocks have global att?
+         self.global_att_blocks = global_att_blocks

-     B, _, dim = q.shape
-     r_q = q.reshape(B, q_h, q_w, dim)
-     rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
-     rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+         # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+         self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+         self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
+         self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))

-     attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
-         B, q_h * q_w, k_h * k_w
-     )
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule

-     return attn
+         cur_stage = 1
+         self.blocks = nn.ModuleList()

+         for i in range(depth):
+             dim_out = embed_dim
+             # lags by a block, so first block of
+             # next stage uses an initial window size
+             # of previous stage and final window size of current stage
+             window_size = self.window_spec[cur_stage - 1]

- class PatchEmbed(nn.Module):
-     """Image to Patch Embedding."""
+             if self.global_att_blocks is not None:
+                 window_size = 0 if i in self.global_att_blocks else window_size

-     def __init__(
-         self,
-         kernel_size: Tuple[int, int] = (16, 16),
-         stride: Tuple[int, int] = (16, 16),
-         padding: Tuple[int, int] = (0, 0),
-         in_chans: int = 3,
-         embed_dim: int = 768,
-     ) -> None:
-         """
-         Initialize PatchEmbed module.
+             if i - 1 in self.stage_ends:
+                 dim_out = int(embed_dim * dim_mul)
+                 num_heads = int(num_heads * head_mul)
+                 cur_stage += 1

-         Args:
-             kernel_size (Tuple): kernel size of the projection layer.
-             stride (Tuple): stride of the projection layer.
-             padding (Tuple): padding size of the projection layer.
-             in_chans (int): Number of input image channels.
-             embed_dim (int): Patch embedding dimension.
-         """
-         super().__init__()
+             block = MultiScaleBlock(
+                 dim=embed_dim,
+                 dim_out=dim_out,
+                 num_heads=num_heads,
+                 drop_path=dpr[i],
+                 q_stride=self.q_stride if i in self.q_pool_blocks else None,
+                 window_size=window_size,
+             )

-         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+             embed_dim = dim_out
+             self.blocks.append(block)

-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Computes patch embedding by applying convolution and transposing resulting tensor."""
-         return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
+         self.channel_list = (
+             [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+             if return_interm_layers
+             else [self.blocks[-1].dim_out]
+         )
+
+     def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+         """Generates positional embeddings by interpolating and combining window and background embeddings."""
+         h, w = hw
+         window_embed = self.pos_embed_window
+         pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+         pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+         pos_embed = pos_embed.permute(0, 2, 3, 1)
+         return pos_embed
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         """Performs forward pass through Hiera model, extracting multiscale features from input images."""
+         x = self.patch_embed(x)
+         # x: (B, H, W, C)
+
+         # Add pos embed
+         x = x + self._get_pos_embed(x.shape[1:3])
+
+         outputs = []
+         for i, blk in enumerate(self.blocks):
+             x = blk(x)
+             if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
+                 feats = x.permute(0, 3, 1, 2)
+                 outputs.append(feats)
+
+         return outputs
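
Note: the sketch below is not part of the diff. It is a minimal usage example of the trunk-neck encoder classes added above (Hiera, FpnNeck, ImageEncoder), written against the constructor signatures visible in this hunk. The import path is assumed from the file listed above (ultralytics/models/sam/modules/encoders.py), and the input resolution and d_model value are illustrative only.

import torch

from ultralytics.models.sam.modules.encoders import FpnNeck, Hiera, ImageEncoder

# Hiera trunk with its default stage layout; channel_list feeds the FPN neck below.
trunk = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
# The neck's channel list must match the trunk's (ImageEncoder asserts this).
neck = FpnNeck(d_model=256, backbone_channel_list=trunk.channel_list)
# scalp=1 discards the lowest-resolution feature level.
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)

out = encoder(torch.randn(1, 3, 1024, 1024))  # dict with vision_features, vision_pos_enc, backbone_fpn
print(out["vision_features"].shape, len(out["backbone_fpn"]))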