dgenerate-ultralytics-headless 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+ from __future__ import annotations

  import copy
  import math
  from functools import partial
- from typing import Any, Optional, Tuple, Type, Union

  import numpy as np
  import torch
@@ -17,8 +17,7 @@ from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis,


  class DropPath(nn.Module):
- """
- Implements stochastic depth regularization for neural networks during training.
+ """Implements stochastic depth regularization for neural networks during training.

  Attributes:
  drop_prob (float): Probability of dropping a path during training.
@@ -33,14 +32,14 @@ class DropPath(nn.Module):
  >>> output = drop_path(x)
  """

- def __init__(self, drop_prob=0.0, scale_by_keep=True):
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
  """Initialize DropPath module for stochastic depth regularization during training."""
  super().__init__()
  self.drop_prob = drop_prob
  self.scale_by_keep = scale_by_keep

- def forward(self, x):
- """Applies stochastic depth to input tensor during training, with optional scaling."""
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply stochastic depth to input tensor during training, with optional scaling."""
  if self.drop_prob == 0.0 or not self.training:
  return x
  keep_prob = 1 - self.drop_prob
@@ -52,16 +51,14 @@ class DropPath(nn.Module):


  class MaskDownSampler(nn.Module):
- """
- A mask downsampling and embedding module for efficient processing of input masks.
+ """A mask downsampling and embedding module for efficient processing of input masks.

- This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks
- while expanding their channel dimensions using convolutional layers, layer normalization, and activation
- functions.
+ This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks while
+ expanding their channel dimensions using convolutional layers, layer normalization, and activation functions.

  Attributes:
- encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
- activation functions for downsampling and embedding masks.
+ encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and activation
+ functions for downsampling and embedding masks.

  Methods:
  forward: Downsamples and encodes input mask to embed_dim channels.
@@ -76,14 +73,14 @@ class MaskDownSampler(nn.Module):

  def __init__(
  self,
- embed_dim=256,
- kernel_size=4,
- stride=4,
- padding=0,
- total_stride=16,
- activation=nn.GELU,
+ embed_dim: int = 256,
+ kernel_size: int = 4,
+ stride: int = 4,
+ padding: int = 0,
+ total_stride: int = 16,
+ activation: type[nn.Module] = nn.GELU,
  ):
- """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
+ """Initialize a mask downsampler module for progressive downsampling and channel expansion."""
  super().__init__()
  num_layers = int(math.log2(total_stride) // math.log2(stride))
  assert stride**num_layers == total_stride
@@ -106,17 +103,16 @@ class MaskDownSampler(nn.Module):

  self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))

- def forward(self, x):
- """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+ def forward(self, x: Tensor) -> Tensor:
+ """Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
  return self.encoder(x)


  class CXBlock(nn.Module):
- """
- ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+ """ConvNeXt Block for efficient feature extraction in convolutional neural networks.

- This block implements a modified version of the ConvNeXt architecture, offering improved performance and
- flexibility in feature extraction.
+ This block implements a modified version of the ConvNeXt architecture, offering improved performance and flexibility
+ in feature extraction.

  Attributes:
  dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
@@ -141,15 +137,14 @@ class CXBlock(nn.Module):

  def __init__(
  self,
- dim,
- kernel_size=7,
- padding=3,
- drop_path=0.0,
- layer_scale_init_value=1e-6,
- use_dwconv=True,
+ dim: int,
+ kernel_size: int = 7,
+ padding: int = 3,
+ drop_path: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ use_dwconv: bool = True,
  ):
- """
- Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+ """Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.

  This block implements a modified version of the ConvNeXt architecture, offering improved performance and
  flexibility in feature extraction.
@@ -188,8 +183,8 @@ class CXBlock(nn.Module):
  )
  self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

- def forward(self, x):
- """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply ConvNeXt block operations to input tensor, including convolutions and residual connection."""
  input = x
  x = self.dwconv(x)
  x = self.norm(x)
@@ -206,8 +201,7 @@ class CXBlock(nn.Module):


  class Fuser(nn.Module):
- """
- A module for fusing features through multiple layers of a neural network.
+ """A module for fusing features through multiple layers of a neural network.

  This class applies a series of identical layers to an input tensor, optionally projecting the input first.

@@ -227,9 +221,8 @@ class Fuser(nn.Module):
  torch.Size([1, 256, 32, 32])
  """

- def __init__(self, layer, num_layers, dim=None, input_projection=False):
- """
- Initializes the Fuser module for feature fusion through multiple layers.
+ def __init__(self, layer: nn.Module, num_layers: int, dim: int | None = None, input_projection: bool = False):
+ """Initialize the Fuser module for feature fusion through multiple layers.

  This module creates a sequence of identical layers and optionally applies an input projection.

@@ -253,8 +246,8 @@ class Fuser(nn.Module):
  assert dim is not None
  self.proj = nn.Conv2d(dim, dim, kernel_size=1)

- def forward(self, x):
- """Applies a series of layers to the input tensor, optionally projecting it first."""
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply a series of layers to the input tensor, optionally projecting it first."""
  x = self.proj(x)
  for layer in self.layers:
  x = layer(x)
@@ -262,12 +255,11 @@ class Fuser(nn.Module):


  class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
- """
- A two-way attention block for performing self-attention and cross-attention in both directions.
+ """A two-way attention block for performing self-attention and cross-attention in both directions.

- This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on
- sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
- cross-attention from dense to sparse inputs.
+ This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse inputs,
+ cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from dense to sparse
+ inputs.

  Attributes:
  self_attn (Attention): Self-attention layer for queries.
@@ -295,16 +287,15 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
  embedding_dim: int,
  num_heads: int,
  mlp_dim: int = 2048,
- activation: Type[nn.Module] = nn.ReLU,
+ activation: type[nn.Module] = nn.ReLU,
  attention_downsample_rate: int = 2,
  skip_first_layer_pe: bool = False,
  ) -> None:
- """
- Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
+ """Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.

  This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
- inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention
- from dense to sparse inputs.
+ inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from
+ dense to sparse inputs.

  Args:
  embedding_dim (int): The channel dimension of the embeddings.
@@ -325,12 +316,11 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):


  class SAM2TwoWayTransformer(TwoWayTransformer):
- """
- A Two-Way Transformer module for simultaneous attention to image and query points.
+ """A Two-Way Transformer module for simultaneous attention to image and query points.

- This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an
- input image using queries with supplied positional embeddings. It is particularly useful for tasks like
- object detection, image segmentation, and point cloud processing.
+ This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an input
+ image using queries with supplied positional embeddings. It is particularly useful for tasks like object detection,
+ image segmentation, and point cloud processing.

  Attributes:
  depth (int): Number of layers in the transformer.
@@ -359,14 +349,13 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
  embedding_dim: int,
  num_heads: int,
  mlp_dim: int,
- activation: Type[nn.Module] = nn.ReLU,
+ activation: type[nn.Module] = nn.ReLU,
  attention_downsample_rate: int = 2,
  ) -> None:
- """
- Initializes a SAM2TwoWayTransformer instance.
+ """Initialize a SAM2TwoWayTransformer instance.

- This transformer decoder attends to an input image using queries with supplied positional embeddings.
- It is designed for tasks like object detection, image segmentation, and point cloud processing.
+ This transformer decoder attends to an input image using queries with supplied positional embeddings. It is
+ designed for tasks like object detection, image segmentation, and point cloud processing.

  Args:
  depth (int): Number of layers in the transformer.
@@ -403,15 +392,14 @@ class SAM2TwoWayTransformer(TwoWayTransformer):


  class RoPEAttention(Attention):
- """
- Implements rotary position encoding for attention mechanisms in transformer architectures.
+ """Implements rotary position encoding for attention mechanisms in transformer architectures.

- This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance
- the positional awareness of the attention mechanism.
+ This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance the
+ positional awareness of the attention mechanism.

  Attributes:
  compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
- freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for rotary encoding.
  rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.

  Methods:
@@ -430,12 +418,12 @@ class RoPEAttention(Attention):
  def __init__(
  self,
  *args,
- rope_theta=10000.0,
- rope_k_repeat=False,
- feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+ rope_theta: float = 10000.0,
+ rope_k_repeat: bool = False,
+ feat_sizes: tuple[int, int] = (32, 32), # [w, h] for stride 16 feats at 512 resolution
  **kwargs,
  ):
- """Initializes RoPEAttention with rotary position encoding for enhanced positional awareness."""
+ """Initialize RoPEAttention with rotary position encoding for enhanced positional awareness."""
  super().__init__(*args, **kwargs)

  self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
@@ -443,8 +431,8 @@ class RoPEAttention(Attention):
  self.freqs_cis = freqs_cis
  self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories

- def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
- """Applies rotary position encoding and computes attention between query, key, and value tensors."""
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_k_exclude_rope: int = 0) -> torch.Tensor:
+ """Apply rotary position encoding and compute attention between query, key, and value tensors."""
  q = self.q_proj(q)
  k = self.k_proj(k)
  v = self.v_proj(v)
@@ -486,7 +474,7 @@ class RoPEAttention(Attention):


  def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
- """Applies pooling and optional normalization to a tensor, handling spatial dimension permutations."""
+ """Apply pooling and optional normalization to a tensor, handling spatial dimension permutations."""
  if pool is None:
  return x
  # (B, H, W, C) -> (B, C, H, W)
@@ -501,12 +489,11 @@ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.T


  class MultiScaleAttention(nn.Module):
- """
- Implements multiscale self-attention with optional query pooling for efficient feature extraction.
+ """Implements multiscale self-attention with optional query pooling for efficient feature extraction.

- This class provides a flexible implementation of multiscale attention, allowing for optional
- downsampling of query features through pooling. It's designed to enhance the model's ability to
- capture multiscale information in visual tasks.
+ This class provides a flexible implementation of multiscale attention, allowing for optional downsampling of query
+ features through pooling. It's designed to enhance the model's ability to capture multiscale information in visual
+ tasks.

  Attributes:
  dim (int): Input dimension of the feature map.
@@ -537,7 +524,7 @@ class MultiScaleAttention(nn.Module):
  num_heads: int,
  q_pool: nn.Module = None,
  ):
- """Initializes multiscale attention with optional query pooling for efficient feature extraction."""
+ """Initialize multiscale attention with optional query pooling for efficient feature extraction."""
  super().__init__()

  self.dim = dim
@@ -552,7 +539,7 @@ class MultiScaleAttention(nn.Module):
  self.proj = nn.Linear(dim_out, dim_out)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies multiscale attention with optional query pooling to extract multiscale features."""
+ """Apply multiscale attention with optional query pooling to extract multiscale features."""
  B, H, W, _ = x.shape
  # qkv with shape (B, H * W, 3, nHead, C)
  qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
@@ -581,11 +568,10 @@ class MultiScaleAttention(nn.Module):


  class MultiScaleBlock(nn.Module):
- """
- A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
+ """A multiscale attention block with window partitioning and query pooling for efficient vision transformers.

- This class implements a multiscale attention mechanism with optional window partitioning and downsampling,
- designed for use in vision transformer architectures.
+ This class implements a multiscale attention mechanism with optional window partitioning and downsampling, designed
+ for use in vision transformer architectures.

  Attributes:
  dim (int): Input dimension of the block.
@@ -593,7 +579,7 @@ class MultiScaleBlock(nn.Module):
  norm1 (nn.Module): First normalization layer.
  window_size (int): Size of the window for partitioning.
  pool (nn.Module | None): Pooling layer for query downsampling.
- q_stride (Tuple[int, int] | None): Stride for query pooling.
+ q_stride (tuple[int, int] | None): Stride for query pooling.
  attn (MultiScaleAttention): Multi-scale attention module.
  drop_path (nn.Module): Drop path layer for regularization.
  norm2 (nn.Module): Second normalization layer.
@@ -618,12 +604,12 @@ class MultiScaleBlock(nn.Module):
  num_heads: int,
  mlp_ratio: float = 4.0,
  drop_path: float = 0.0,
- norm_layer: Union[nn.Module, str] = "LayerNorm",
- q_stride: Tuple[int, int] = None,
- act_layer: nn.Module = nn.GELU,
+ norm_layer: nn.Module | str = "LayerNorm",
+ q_stride: tuple[int, int] | None = None,
+ act_layer: type[nn.Module] = nn.GELU,
  window_size: int = 0,
  ):
- """Initializes a multiscale attention block with window partitioning and optional query pooling."""
+ """Initialize a multiscale attention block with window partitioning and optional query pooling."""
  super().__init__()

  if isinstance(norm_layer, str):
@@ -660,7 +646,7 @@ class MultiScaleBlock(nn.Module):
  self.proj = nn.Linear(dim, dim_out)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Processes input through multiscale attention and MLP, with optional windowing and downsampling."""
+ """Process input through multiscale attention and MLP, with optional windowing and downsampling."""
  shortcut = x # B, H, W, C
  x = self.norm1(x)

@@ -696,11 +682,10 @@ class MultiScaleBlock(nn.Module):


  class PositionEmbeddingSine(nn.Module):
- """
- A module for generating sinusoidal positional embeddings for 2D inputs like images.
+ """A module for generating sinusoidal positional embeddings for 2D inputs like images.

- This class implements sinusoidal position encoding for 2D spatial positions, which can be used in
- transformer-based models for computer vision tasks.
+ This class implements sinusoidal position encoding for 2D spatial positions, which can be used in transformer-based
+ models for computer vision tasks.

  Attributes:
  num_pos_feats (int): Number of positional features (half of the embedding dimension).
@@ -725,12 +710,12 @@ class PositionEmbeddingSine(nn.Module):

  def __init__(
  self,
- num_pos_feats,
+ num_pos_feats: int,
  temperature: int = 10000,
  normalize: bool = True,
- scale: Optional[float] = None,
+ scale: float | None = None,
  ):
- """Initializes sinusoidal position embeddings for 2D image inputs."""
+ """Initialize sinusoidal position embeddings for 2D image inputs."""
  super().__init__()
  assert num_pos_feats % 2 == 0, "Expecting even model width"
  self.num_pos_feats = num_pos_feats // 2
@@ -744,8 +729,8 @@ class PositionEmbeddingSine(nn.Module):

  self.cache = {}

- def _encode_xy(self, x, y):
- """Encodes 2D positions using sine/cosine functions for transformer positional embeddings."""
+ def _encode_xy(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """Encode 2D positions using sine/cosine functions for transformer positional embeddings."""
  assert len(x) == len(y) and x.ndim == y.ndim == 1
  x_embed = x * self.scale
  y_embed = y * self.scale
@@ -760,16 +745,16 @@ class PositionEmbeddingSine(nn.Module):
  return pos_x, pos_y

  @torch.no_grad()
- def encode_boxes(self, x, y, w, h):
- """Encodes box coordinates and dimensions into positional embeddings for detection."""
+ def encode_boxes(self, x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ """Encode box coordinates and dimensions into positional embeddings for detection."""
  pos_x, pos_y = self._encode_xy(x, y)
  return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)

  encode = encode_boxes # Backwards compatibility

  @torch.no_grad()
- def encode_points(self, x, y, labels):
- """Encodes 2D points with sinusoidal embeddings and appends labels."""
+ def encode_points(self, x: torch.Tensor, y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ """Encode 2D points with sinusoidal embeddings and append labels."""
  (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
  assert bx == by and nx == ny and bx == bl and nx == nl
  pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
@@ -777,8 +762,8 @@ class PositionEmbeddingSine(nn.Module):
  return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)

  @torch.no_grad()
- def forward(self, x: torch.Tensor):
- """Generates sinusoidal position embeddings for 2D inputs like images."""
+ def forward(self, x: torch.Tensor) -> Tensor:
+ """Generate sinusoidal position embeddings for 2D inputs like images."""
  cache_key = (x.shape[-2], x.shape[-1])
  if cache_key in self.cache:
  return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
@@ -811,8 +796,7 @@ class PositionEmbeddingSine(nn.Module):


  class PositionEmbeddingRandom(nn.Module):
- """
- Positional encoding using random spatial frequencies.
+ """Positional encoding using random spatial frequencies.

  This class generates positional embeddings for input coordinates using random spatial frequencies. It is
  particularly useful for transformer-based models that require position information.
@@ -833,8 +817,8 @@ class PositionEmbeddingRandom(nn.Module):
  torch.Size([128, 32, 32])
  """

- def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
- """Initializes random spatial frequency position embedding for transformers."""
+ def __init__(self, num_pos_feats: int = 64, scale: float | None = None) -> None:
+ """Initialize random spatial frequency position embedding for transformers."""
  super().__init__()
  if scale is None or scale <= 0.0:
  scale = 1.0
@@ -845,7 +829,7 @@ class PositionEmbeddingRandom(nn.Module):
  torch.backends.cudnn.deterministic = False

  def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
- """Encodes normalized [0,1] coordinates using random spatial frequencies."""
+ """Encode normalized [0,1] coordinates using random spatial frequencies."""
  # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  coords = 2 * coords - 1
  coords = coords @ self.positional_encoding_gaussian_matrix
@@ -853,11 +837,14 @@ class PositionEmbeddingRandom(nn.Module):
  # Outputs d_1 x ... x d_n x C shape
  return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

- def forward(self, size: Tuple[int, int]) -> torch.Tensor:
- """Generates positional encoding for a grid using random spatial frequencies."""
+ def forward(self, size: tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid using random spatial frequencies."""
  h, w = size
- device: Any = self.positional_encoding_gaussian_matrix.device
- grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ grid = torch.ones(
+ (h, w),
+ device=self.positional_encoding_gaussian_matrix.device,
+ dtype=self.positional_encoding_gaussian_matrix.dtype,
+ )
  y_embed = grid.cumsum(dim=0) - 0.5
  x_embed = grid.cumsum(dim=1) - 0.5
  y_embed = y_embed / h
@@ -866,21 +853,20 @@ class PositionEmbeddingRandom(nn.Module):
  pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
  return pe.permute(2, 0, 1) # C x H x W

- def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
- """Positionally encodes input coordinates, normalizing them to [0,1] based on the given image size."""
+ def forward_with_coords(self, coords_input: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+ """Positionally encode input coordinates, normalizing them to [0,1] based on the given image size."""
  coords = coords_input.clone()
  coords[:, :, 0] = coords[:, :, 0] / image_size[1]
  coords[:, :, 1] = coords[:, :, 1] / image_size[0]
- return self._pe_encoding(coords.to(torch.float)) # B x N x C
+ return self._pe_encoding(coords) # B x N x C


  class Block(nn.Module):
- """
- Transformer block with support for window attention and residual propagation.
+ """Transformer block with support for window attention and residual propagation.

- This class implements a transformer block that can use either global or windowed self-attention,
- followed by a feed-forward network. It supports relative positional embeddings and is designed
- for use in vision transformer architectures.
+ This class implements a transformer block that can use either global or windowed self-attention, followed by a
+ feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
+ architectures.

  Attributes:
  norm1 (nn.Module): First normalization layer.
@@ -907,19 +893,18 @@ class Block(nn.Module):
  num_heads: int,
  mlp_ratio: float = 4.0,
  qkv_bias: bool = True,
- norm_layer: Type[nn.Module] = nn.LayerNorm,
- act_layer: Type[nn.Module] = nn.GELU,
+ norm_layer: type[nn.Module] = nn.LayerNorm,
+ act_layer: type[nn.Module] = nn.GELU,
  use_rel_pos: bool = False,
  rel_pos_zero_init: bool = True,
  window_size: int = 0,
- input_size: Optional[Tuple[int, int]] = None,
+ input_size: tuple[int, int] | None = None,
  ) -> None:
- """
- Initializes a transformer block with optional window attention and relative positional embeddings.
+ """Initialize a transformer block with optional window attention and relative positional embeddings.

- This constructor sets up a transformer block that can use either global or windowed self-attention,
- followed by a feed-forward network. It supports relative positional embeddings and is designed
- for use in vision transformer architectures.
+ This constructor sets up a transformer block that can use either global or windowed self-attention, followed by
+ a feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
+ architectures.

  Args:
  dim (int): Number of input channels.
@@ -931,7 +916,7 @@ class Block(nn.Module):
  use_rel_pos (bool): If True, uses relative positional embeddings in attention.
  rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
  window_size (int): Size of attention window. If 0, uses global attention.
- input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size.
+ input_size (tuple[int, int] | None): Input resolution for calculating relative positional parameter size.

  Examples:
  >>> block = Block(dim=256, num_heads=8, window_size=7)
@@ -957,7 +942,7 @@ class Block(nn.Module):
  self.window_size = window_size

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Processes input through transformer block with optional windowed self-attention and residual connection."""
+ """Process input through transformer block with optional windowed self-attention and residual connection."""
  shortcut = x
  x = self.norm1(x)
  # Window partition
@@ -975,35 +960,30 @@ class Block(nn.Module):


  class REAttention(nn.Module):
- """
- Rotary Embedding Attention module for efficient self-attention in transformer architectures.
+ """Relative Position Attention module for efficient self-attention in transformer architectures.

- This class implements a multi-head attention mechanism with rotary positional embeddings, designed
- for use in vision transformer models. It supports optional query pooling and window partitioning
- for efficient processing of large inputs.
+ This class implements a multi-head attention mechanism with relative positional embeddings, designed for use in
+ vision transformer models. It supports optional query pooling and window partitioning for efficient processing of
+ large inputs.

  Attributes:
- compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
- freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
- rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
- q_proj (nn.Linear): Linear projection for query.
- k_proj (nn.Linear): Linear projection for key.
- v_proj (nn.Linear): Linear projection for value.
- out_proj (nn.Linear): Output projection.
  num_heads (int): Number of attention heads.
- internal_dim (int): Internal dimension for attention computation.
+ scale (float): Scaling factor for attention computation.
+ qkv (nn.Linear): Linear projection for query, key, and value.
+ proj (nn.Linear): Output projection layer.
+ use_rel_pos (bool): Whether to use relative positional embeddings.
+ rel_pos_h (nn.Parameter): Relative positional embeddings for height dimension.
+ rel_pos_w (nn.Parameter): Relative positional embeddings for width dimension.

  Methods:
- forward: Applies rotary position encoding and computes attention between query, key, and value tensors.
+ forward: Applies multi-head attention with optional relative positional encoding to input tensor.

  Examples:
- >>> rope_attn = REAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
- >>> q = torch.randn(1, 1024, 256)
- >>> k = torch.randn(1, 1024, 256)
- >>> v = torch.randn(1, 1024, 256)
- >>> output = rope_attn(q, k, v)
+ >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+ >>> x = torch.randn(1, 32, 32, 256)
+ >>> output = attention(x)
  >>> print(output.shape)
- torch.Size([1, 1024, 256])
+ torch.Size([1, 32, 32, 256])
  """

  def __init__(
@@ -1013,22 +993,21 @@ class REAttention(nn.Module):
  qkv_bias: bool = True,
  use_rel_pos: bool = False,
  rel_pos_zero_init: bool = True,
- input_size: Optional[Tuple[int, int]] = None,
+ input_size: tuple[int, int] | None = None,
  ) -> None:
- """
- Initializes a Relative Position Attention module for transformer-based architectures.
+ """Initialize a Relative Position Attention module for transformer-based architectures.

- This module implements multi-head attention with optional relative positional encodings, designed
- specifically for vision tasks in transformer models.
+ This module implements multi-head attention with optional relative positional encodings, designed specifically
+ for vision tasks in transformer models.

  Args:
  dim (int): Number of input channels.
- num_heads (int): Number of attention heads. Default is 8.
- qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True.
- use_rel_pos (bool): If True, uses relative positional encodings. Default is False.
- rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True.
- input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
- Required if use_rel_pos is True. Default is None.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
+ use_rel_pos (bool): If True, uses relative positional encodings.
+ rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+ input_size (tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
+ Required if use_rel_pos is True.

  Examples:
  >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
@@ -1053,7 +1032,7 @@ class REAttention(nn.Module):
  self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies multi-head attention with optional relative positional encoding to input tensor."""
+ """Apply multi-head attention with optional relative positional encoding to input tensor."""
  B, H, W, _ = x.shape
  # qkv with shape (3, B, nHead, H * W, C)
  qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -1071,12 +1050,11 @@ class REAttention(nn.Module):


  class PatchEmbed(nn.Module):
- """
- Image to Patch Embedding module for vision transformer architectures.
+ """Image to Patch Embedding module for vision transformer architectures.

- This module converts an input image into a sequence of patch embeddings using a convolutional layer.
- It is commonly used as the first layer in vision transformer architectures to transform image data
- into a suitable format for subsequent transformer blocks.
+ This module converts an input image into a sequence of patch embeddings using a convolutional layer. It is commonly
+ used as the first layer in vision transformer architectures to transform image data into a suitable format for
+ subsequent transformer blocks.

  Attributes:
  proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.
@@ -1094,22 +1072,21 @@ class PatchEmbed(nn.Module):

  def __init__(
  self,
- kernel_size: Tuple[int, int] = (16, 16),
- stride: Tuple[int, int] = (16, 16),
- padding: Tuple[int, int] = (0, 0),
+ kernel_size: tuple[int, int] = (16, 16),
+ stride: tuple[int, int] = (16, 16),
+ padding: tuple[int, int] = (0, 0),
  in_chans: int = 3,
  embed_dim: int = 768,
  ) -> None:
- """
- Initializes the PatchEmbed module for converting image patches to embeddings.
+ """Initialize the PatchEmbed module for converting image patches to embeddings.

- This module is typically used as the first layer in vision transformer architectures to transform
- image data into a suitable format for subsequent transformer blocks.
+ This module is typically used as the first layer in vision transformer architectures to transform image data
+ into a suitable format for subsequent transformer blocks.

  Args:
- kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction.
- stride (Tuple[int, int]): Stride of the convolutional operation.
- padding (Tuple[int, int]): Padding applied to the input before convolution.
+ kernel_size (tuple[int, int]): Size of the convolutional kernel for patch extraction.
+ stride (tuple[int, int]): Stride of the convolutional operation.
+ padding (tuple[int, int]): Padding applied to the input before convolution.
  in_chans (int): Number of input image channels.
  embed_dim (int): Dimensionality of the output patch embeddings.

@@ -1125,5 +1102,5 @@ class PatchEmbed(nn.Module):
  self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Computes patch embedding by applying convolution and transposing resulting tensor."""
+ """Compute patch embedding by applying convolution and transposing resulting tensor."""
  return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C