dgenerate-ultralytics-headless 8.3.141__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
  2. dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +12 -12
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +22 -19
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +39 -39
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +187 -158
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +1 -1
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +13 -11
  94. ultralytics/solutions/heatmap.py +1 -1
  95. ultralytics/solutions/instance_segmentation.py +6 -3
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +18 -12
  98. ultralytics/solutions/object_cropper.py +12 -5
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +215 -85
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +42 -28
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +84 -42
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.141.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/blocks.py

@@ -33,14 +33,14 @@ class DropPath(nn.Module):
  >>> output = drop_path(x)
  """

- def __init__(self, drop_prob=0.0, scale_by_keep=True):
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
  """Initialize DropPath module for stochastic depth regularization during training."""
  super().__init__()
  self.drop_prob = drop_prob
  self.scale_by_keep = scale_by_keep

- def forward(self, x):
- """Applies stochastic depth to input tensor during training, with optional scaling."""
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply stochastic depth to input tensor during training, with optional scaling."""
  if self.drop_prob == 0.0 or not self.training:
  return x
  keep_prob = 1 - self.drop_prob
@@ -76,14 +76,14 @@ class MaskDownSampler(nn.Module):

  def __init__(
  self,
- embed_dim=256,
- kernel_size=4,
- stride=4,
- padding=0,
- total_stride=16,
- activation=nn.GELU,
+ embed_dim: int = 256,
+ kernel_size: int = 4,
+ stride: int = 4,
+ padding: int = 0,
+ total_stride: int = 16,
+ activation: Type[nn.Module] = nn.GELU,
  ):
- """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
+ """Initialize a mask downsampler module for progressive downsampling and channel expansion."""
  super().__init__()
  num_layers = int(math.log2(total_stride) // math.log2(stride))
  assert stride**num_layers == total_stride
@@ -106,8 +106,8 @@ class MaskDownSampler(nn.Module):

  self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))

- def forward(self, x):
- """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+ def forward(self, x: Tensor) -> Tensor:
+ """Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
  return self.encoder(x)


@@ -141,12 +141,12 @@ class CXBlock(nn.Module):

  def __init__(
  self,
- dim,
- kernel_size=7,
- padding=3,
- drop_path=0.0,
- layer_scale_init_value=1e-6,
- use_dwconv=True,
+ dim: int,
+ kernel_size: int = 7,
+ padding: int = 3,
+ drop_path: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ use_dwconv: bool = True,
  ):
  """
  Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
@@ -188,8 +188,8 @@ class CXBlock(nn.Module):
  )
  self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

- def forward(self, x):
- """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply ConvNeXt block operations to input tensor, including convolutions and residual connection."""
  input = x
  x = self.dwconv(x)
  x = self.norm(x)
@@ -227,9 +227,9 @@ class Fuser(nn.Module):
  torch.Size([1, 256, 32, 32])
  """

- def __init__(self, layer, num_layers, dim=None, input_projection=False):
+ def __init__(self, layer: nn.Module, num_layers: int, dim: Optional[int] = None, input_projection: bool = False):
  """
- Initializes the Fuser module for feature fusion through multiple layers.
+ Initialize the Fuser module for feature fusion through multiple layers.

  This module creates a sequence of identical layers and optionally applies an input projection.

@@ -253,8 +253,8 @@ class Fuser(nn.Module):
  assert dim is not None
  self.proj = nn.Conv2d(dim, dim, kernel_size=1)

- def forward(self, x):
- """Applies a series of layers to the input tensor, optionally projecting it first."""
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply a series of layers to the input tensor, optionally projecting it first."""
  x = self.proj(x)
  for layer in self.layers:
  x = layer(x)
@@ -300,7 +300,7 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
  skip_first_layer_pe: bool = False,
  ) -> None:
  """
- Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
+ Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.

  This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
  inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention
@@ -363,7 +363,7 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
  attention_downsample_rate: int = 2,
  ) -> None:
  """
- Initializes a SAM2TwoWayTransformer instance.
+ Initialize a SAM2TwoWayTransformer instance.

  This transformer decoder attends to an input image using queries with supplied positional embeddings.
  It is designed for tasks like object detection, image segmentation, and point cloud processing.
@@ -430,12 +430,12 @@ class RoPEAttention(Attention):
  def __init__(
  self,
  *args,
- rope_theta=10000.0,
- rope_k_repeat=False,
- feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+ rope_theta: float = 10000.0,
+ rope_k_repeat: bool = False,
+ feat_sizes: Tuple[int, int] = (32, 32), # [w, h] for stride 16 feats at 512 resolution
  **kwargs,
  ):
- """Initializes RoPEAttention with rotary position encoding for enhanced positional awareness."""
+ """Initialize RoPEAttention with rotary position encoding for enhanced positional awareness."""
  super().__init__(*args, **kwargs)

  self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
@@ -444,7 +444,7 @@ class RoPEAttention(Attention):
  self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories

  def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
- """Applies rotary position encoding and computes attention between query, key, and value tensors."""
+ """Apply rotary position encoding and compute attention between query, key, and value tensors."""
  q = self.q_proj(q)
  k = self.k_proj(k)
  v = self.v_proj(v)
@@ -486,7 +486,7 @@ class RoPEAttention(Attention):


  def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
- """Applies pooling and optional normalization to a tensor, handling spatial dimension permutations."""
+ """Apply pooling and optional normalization to a tensor, handling spatial dimension permutations."""
  if pool is None:
  return x
  # (B, H, W, C) -> (B, C, H, W)
@@ -537,7 +537,7 @@ class MultiScaleAttention(nn.Module):
  num_heads: int,
  q_pool: nn.Module = None,
  ):
- """Initializes multiscale attention with optional query pooling for efficient feature extraction."""
+ """Initialize multiscale attention with optional query pooling for efficient feature extraction."""
  super().__init__()

  self.dim = dim
@@ -552,7 +552,7 @@ class MultiScaleAttention(nn.Module):
  self.proj = nn.Linear(dim_out, dim_out)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies multiscale attention with optional query pooling to extract multiscale features."""
+ """Apply multiscale attention with optional query pooling to extract multiscale features."""
  B, H, W, _ = x.shape
  # qkv with shape (B, H * W, 3, nHead, C)
  qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
@@ -620,10 +620,10 @@ class MultiScaleBlock(nn.Module):
  drop_path: float = 0.0,
  norm_layer: Union[nn.Module, str] = "LayerNorm",
  q_stride: Tuple[int, int] = None,
- act_layer: nn.Module = nn.GELU,
+ act_layer: Type[nn.Module] = nn.GELU,
  window_size: int = 0,
  ):
- """Initializes a multiscale attention block with window partitioning and optional query pooling."""
+ """Initialize a multiscale attention block with window partitioning and optional query pooling."""
  super().__init__()

  if isinstance(norm_layer, str):
@@ -660,7 +660,7 @@ class MultiScaleBlock(nn.Module):
  self.proj = nn.Linear(dim, dim_out)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Processes input through multiscale attention and MLP, with optional windowing and downsampling."""
+ """Process input through multiscale attention and MLP, with optional windowing and downsampling."""
  shortcut = x # B, H, W, C
  x = self.norm1(x)

@@ -725,12 +725,12 @@ class PositionEmbeddingSine(nn.Module):

  def __init__(
  self,
- num_pos_feats,
+ num_pos_feats: int,
  temperature: int = 10000,
  normalize: bool = True,
  scale: Optional[float] = None,
  ):
- """Initializes sinusoidal position embeddings for 2D image inputs."""
+ """Initialize sinusoidal position embeddings for 2D image inputs."""
  super().__init__()
  assert num_pos_feats % 2 == 0, "Expecting even model width"
  self.num_pos_feats = num_pos_feats // 2
@@ -744,8 +744,8 @@ class PositionEmbeddingSine(nn.Module):

  self.cache = {}

- def _encode_xy(self, x, y):
- """Encodes 2D positions using sine/cosine functions for transformer positional embeddings."""
+ def _encode_xy(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encode 2D positions using sine/cosine functions for transformer positional embeddings."""
  assert len(x) == len(y) and x.ndim == y.ndim == 1
  x_embed = x * self.scale
  y_embed = y * self.scale
@@ -760,16 +760,16 @@ class PositionEmbeddingSine(nn.Module):
  return pos_x, pos_y

  @torch.no_grad()
- def encode_boxes(self, x, y, w, h):
- """Encodes box coordinates and dimensions into positional embeddings for detection."""
+ def encode_boxes(self, x: Tensor, y: Tensor, w: Tensor, h: Tensor) -> Tensor:
+ """Encode box coordinates and dimensions into positional embeddings for detection."""
  pos_x, pos_y = self._encode_xy(x, y)
  return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)

  encode = encode_boxes # Backwards compatibility

  @torch.no_grad()
- def encode_points(self, x, y, labels):
- """Encodes 2D points with sinusoidal embeddings and appends labels."""
+ def encode_points(self, x: Tensor, y: Tensor, labels: Tensor) -> Tensor:
+ """Encode 2D points with sinusoidal embeddings and append labels."""
  (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
  assert bx == by and nx == ny and bx == bl and nx == nl
  pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
@@ -777,8 +777,8 @@ class PositionEmbeddingSine(nn.Module):
  return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)

  @torch.no_grad()
- def forward(self, x: torch.Tensor):
- """Generates sinusoidal position embeddings for 2D inputs like images."""
+ def forward(self, x: torch.Tensor) -> Tensor:
+ """Generate sinusoidal position embeddings for 2D inputs like images."""
  cache_key = (x.shape[-2], x.shape[-1])
  if cache_key in self.cache:
  return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
@@ -834,7 +834,7 @@ class PositionEmbeddingRandom(nn.Module):
  """

  def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
- """Initializes random spatial frequency position embedding for transformers."""
+ """Initialize random spatial frequency position embedding for transformers."""
  super().__init__()
  if scale is None or scale <= 0.0:
  scale = 1.0
@@ -845,7 +845,7 @@ class PositionEmbeddingRandom(nn.Module):
  torch.backends.cudnn.deterministic = False

  def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
- """Encodes normalized [0,1] coordinates using random spatial frequencies."""
+ """Encode normalized [0,1] coordinates using random spatial frequencies."""
  # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  coords = 2 * coords - 1
  coords = coords @ self.positional_encoding_gaussian_matrix
@@ -854,7 +854,7 @@ class PositionEmbeddingRandom(nn.Module):
  return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

  def forward(self, size: Tuple[int, int]) -> torch.Tensor:
- """Generates positional encoding for a grid using random spatial frequencies."""
+ """Generate positional encoding for a grid using random spatial frequencies."""
  h, w = size
  device: Any = self.positional_encoding_gaussian_matrix.device
  grid = torch.ones((h, w), device=device, dtype=torch.float32)
@@ -867,7 +867,7 @@ class PositionEmbeddingRandom(nn.Module):
  return pe.permute(2, 0, 1) # C x H x W

  def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
- """Positionally encodes input coordinates, normalizing them to [0,1] based on the given image size."""
+ """Positionally encode input coordinates, normalizing them to [0,1] based on the given image size."""
  coords = coords_input.clone()
  coords[:, :, 0] = coords[:, :, 0] / image_size[1]
  coords[:, :, 1] = coords[:, :, 1] / image_size[0]
@@ -915,7 +915,7 @@ class Block(nn.Module):
  input_size: Optional[Tuple[int, int]] = None,
  ) -> None:
  """
- Initializes a transformer block with optional window attention and relative positional embeddings.
+ Initialize a transformer block with optional window attention and relative positional embeddings.

  This constructor sets up a transformer block that can use either global or windowed self-attention,
  followed by a feed-forward network. It supports relative positional embeddings and is designed
@@ -931,7 +931,7 @@ class Block(nn.Module):
  use_rel_pos (bool): If True, uses relative positional embeddings in attention.
  rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
  window_size (int): Size of attention window. If 0, uses global attention.
- input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size.
+ input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.

  Examples:
  >>> block = Block(dim=256, num_heads=8, window_size=7)
@@ -957,7 +957,7 @@ class Block(nn.Module):
  self.window_size = window_size

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Processes input through transformer block with optional windowed self-attention and residual connection."""
+ """Process input through transformer block with optional windowed self-attention and residual connection."""
  shortcut = x
  x = self.norm1(x)
  # Window partition
@@ -976,34 +976,30 @@ class Block(nn.Module):

  class REAttention(nn.Module):
  """
- Rotary Embedding Attention module for efficient self-attention in transformer architectures.
+ Relative Position Attention module for efficient self-attention in transformer architectures.

- This class implements a multi-head attention mechanism with rotary positional embeddings, designed
+ This class implements a multi-head attention mechanism with relative positional embeddings, designed
  for use in vision transformer models. It supports optional query pooling and window partitioning
  for efficient processing of large inputs.

  Attributes:
- compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
- freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
- rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
- q_proj (nn.Linear): Linear projection for query.
- k_proj (nn.Linear): Linear projection for key.
- v_proj (nn.Linear): Linear projection for value.
- out_proj (nn.Linear): Output projection.
  num_heads (int): Number of attention heads.
- internal_dim (int): Internal dimension for attention computation.
+ scale (float): Scaling factor for attention computation.
+ qkv (nn.Linear): Linear projection for query, key, and value.
+ proj (nn.Linear): Output projection layer.
+ use_rel_pos (bool): Whether to use relative positional embeddings.
+ rel_pos_h (nn.Parameter): Relative positional embeddings for height dimension.
+ rel_pos_w (nn.Parameter): Relative positional embeddings for width dimension.

  Methods:
- forward: Applies rotary position encoding and computes attention between query, key, and value tensors.
+ forward: Applies multi-head attention with optional relative positional encoding to input tensor.

  Examples:
- >>> rope_attn = REAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
- >>> q = torch.randn(1, 1024, 256)
- >>> k = torch.randn(1, 1024, 256)
- >>> v = torch.randn(1, 1024, 256)
- >>> output = rope_attn(q, k, v)
+ >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+ >>> x = torch.randn(1, 32, 32, 256)
+ >>> output = attention(x)
  >>> print(output.shape)
- torch.Size([1, 1024, 256])
+ torch.Size([1, 32, 32, 256])
  """

  def __init__(
@@ -1016,19 +1012,19 @@ class REAttention(nn.Module):
  input_size: Optional[Tuple[int, int]] = None,
  ) -> None:
  """
- Initializes a Relative Position Attention module for transformer-based architectures.
+ Initialize a Relative Position Attention module for transformer-based architectures.

  This module implements multi-head attention with optional relative positional encodings, designed
  specifically for vision tasks in transformer models.

  Args:
  dim (int): Number of input channels.
- num_heads (int): Number of attention heads. Default is 8.
- qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True.
- use_rel_pos (bool): If True, uses relative positional encodings. Default is False.
- rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
+ use_rel_pos (bool): If True, uses relative positional encodings.
+ rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
  input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
- Required if use_rel_pos is True. Default is None.
+ Required if use_rel_pos is True.

  Examples:
  >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
@@ -1053,7 +1049,7 @@ class REAttention(nn.Module):
  self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies multi-head attention with optional relative positional encoding to input tensor."""
+ """Apply multi-head attention with optional relative positional encoding to input tensor."""
  B, H, W, _ = x.shape
  # qkv with shape (3, B, nHead, H * W, C)
  qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -1101,7 +1097,7 @@ class PatchEmbed(nn.Module):
  embed_dim: int = 768,
  ) -> None:
  """
- Initializes the PatchEmbed module for converting image patches to embeddings.
+ Initialize the PatchEmbed module for converting image patches to embeddings.

  This module is typically used as the first layer in vision transformer architectures to transform
  image data into a suitable format for subsequent transformer blocks.
@@ -1125,5 +1121,5 @@ class PatchEmbed(nn.Module):
  self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Computes patch embedding by applying convolution and transposing resulting tensor."""
+ """Compute patch embedding by applying convolution and transposing resulting tensor."""
  return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
ultralytics/models/sam/modules/decoders.py

@@ -27,7 +27,7 @@ class MaskDecoder(nn.Module):
  iou_prediction_head (nn.Module): MLP for predicting mask quality.

  Methods:
- forward: Predicts masks given image and prompt embeddings.
+ forward: Predict masks given image and prompt embeddings.
  predict_masks: Internal method for mask prediction.

  Examples:
@@ -129,7 +129,6 @@ class MaskDecoder(nn.Module):
  masks = masks[:, mask_slice, :, :]
  iou_pred = iou_pred[:, mask_slice]

- # Prepare output
  return masks, iou_pred

  def predict_masks(
@@ -201,10 +200,10 @@ class SAM2MaskDecoder(nn.Module):
  dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.

  Methods:
- forward: Predicts masks given image and prompt embeddings.
- predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
- _get_stability_scores: Computes mask stability scores based on IoU between thresholds.
- _dynamic_multimask_via_stability: Dynamically selects the most stable mask output.
+ forward: Predict masks given image and prompt embeddings.
+ predict_masks: Predict instance segmentation masks from image and prompt embeddings.
+ _get_stability_scores: Compute mask stability scores based on IoU between thresholds.
+ _dynamic_multimask_via_stability: Dynamically select the most stable mask output.

  Examples:
  >>> image_embeddings = torch.rand(1, 256, 64, 64)
@@ -330,7 +329,7 @@ class SAM2MaskDecoder(nn.Module):
  dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
  multimask_output (bool): Whether to return multiple masks or a single mask.
  repeat_image (bool): Flag to repeat the image embeddings.
- high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
+ high_res_features (List[torch.Tensor] | None, optional): Optional high-resolution features.

  Returns:
  masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
@@ -377,7 +376,6 @@ class SAM2MaskDecoder(nn.Module):
  # are always the single mask token (and we'll let it be the object-memory token).
  sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape

- # Prepare output
  return masks, iou_pred, sam_tokens_out, object_score_logits

  def predict_masks(
ultralytics/models/sam/modules/encoders.py

@@ -35,7 +35,7 @@ class ImageEncoderViT(nn.Module):
  neck (nn.Sequential): Neck module to further process the output.

  Methods:
- forward: Processes input through patch embedding, positional embedding, blocks, and neck.
+ forward: Process input through patch embedding, positional embedding, blocks, and neck.

  Examples:
  >>> import torch
@@ -103,7 +103,7 @@ class ImageEncoderViT(nn.Module):

  self.pos_embed: Optional[nn.Parameter] = None
  if use_abs_pos:
- # Initialize absolute positional embedding with pretrain image size.
+ # Initialize absolute positional embedding with pretrain image size
  self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))

  self.blocks = nn.ModuleList()
@@ -157,7 +157,7 @@ class ImageEncoderViT(nn.Module):

  class PromptEncoder(nn.Module):
  """
- Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
+ Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.

  Attributes:
  embed_dim (int): Dimension of the embeddings.
@@ -172,8 +172,8 @@ class PromptEncoder(nn.Module):
  no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.

  Methods:
- get_dense_pe: Returns the positional encoding used to encode point prompts.
- forward: Embeds different types of prompts, returning both sparse and dense embeddings.
+ get_dense_pe: Return the positional encoding used to encode point prompts.
+ forward: Embed different types of prompts, returning both sparse and dense embeddings.

  Examples:
  >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
@@ -321,9 +321,8 @@ class PromptEncoder(nn.Module):
  masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).

  Returns:
- (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
- - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
- - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
+ sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
+ dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).

  Examples:
  >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
@@ -394,7 +393,7 @@ class MemoryEncoder(nn.Module):

  Args:
  out_dim (int): Output dimension of the encoded features.
- in_dim (int): Input dimension of the pixel features. Default is 256.
+ in_dim (int): Input dimension of the pixel features.

  Examples:
  >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
@@ -420,7 +419,7 @@ class MemoryEncoder(nn.Module):
  pix_feat: torch.Tensor,
  masks: torch.Tensor,
  skip_mask_sigmoid: bool = False,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ) -> dict:
  """Process pixel features and masks to generate encoded memory representations for segmentation."""
  if not skip_mask_sigmoid:
  masks = F.sigmoid(masks)
@@ -499,7 +498,7 @@ class ImageEncoder(nn.Module):
  )

  def forward(self, sample: torch.Tensor):
- """Encode input through patch embedding, positional embedding, transformer blocks, and neck module."""
+ """Encode input through trunk and neck networks, returning multiscale features and positional encodings."""
  features, pos = self.neck(self.trunk(sample))
  if self.scalp > 0:
  # Discard the lowest resolution features
@@ -552,7 +551,7 @@ class FpnNeck(nn.Module):
  fpn_top_down_levels: Optional[List[int]] = None,
  ):
  """
- Initializes a modified Feature Pyramid Network (FPN) neck.
+ Initialize a modified Feature Pyramid Network (FPN) neck.

  This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
  similar to ViT positional embedding interpolation.
@@ -594,18 +593,18 @@ class FpnNeck(nn.Module):
  assert fuse_type in {"sum", "avg"}
  self.fuse_type = fuse_type

- # levels to have top-down features in its outputs
+ # Levels to have top-down features in its outputs
  # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
  # have top-down propagation, while outputs of level 0 and level 1 have only
- # lateral features from the same backbone level.
+ # lateral features from the same backbone level
  if fpn_top_down_levels is None:
- # default is to have top-down features on all levels
+ # Default is to have top-down features on all levels
  fpn_top_down_levels = range(len(self.convs))
  self.fpn_top_down_levels = list(fpn_top_down_levels)

  def forward(self, xs: List[torch.Tensor]):
  """
- Performs forward pass through the Feature Pyramid Network (FPN) neck.
+ Perform forward pass through the Feature Pyramid Network (FPN) neck.

  This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
  and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
@@ -614,10 +613,10 @@ class FpnNeck(nn.Module):
  xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).

  Returns:
- (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
- - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
- (B, d_model, H, W).
- - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
+ out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
+ (B, d_model, H, W).
+ pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.

  Examples:
  >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
@@ -629,10 +627,10 @@ class FpnNeck(nn.Module):
  out = [None] * len(self.convs)
  pos = [None] * len(self.convs)
  assert len(xs) == len(self.convs)
- # fpn forward pass
+ # FPN forward pass
  # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
  prev_features = None
- # forward in top-down order (from low to high resolution)
+ # Forward in top-down order (from low to high resolution)
  n = len(self.convs) - 1
  for i in range(n, -1, -1):
  x = xs[i]
@@ -763,7 +761,7 @@ class Hiera(nn.Module):
  stride=(4, 4),
  padding=(3, 3),
  )
- # Which blocks have global att?
+ # Which blocks have global attention?
  self.global_att_blocks = global_att_blocks

  # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
@@ -778,8 +776,7 @@ class Hiera(nn.Module):

  for i in range(depth):
  dim_out = embed_dim
- # lags by a block, so first block of
- # next stage uses an initial window size
+ # Lags by a block, so first block of next stage uses an initial window size
  # of previous stage and final window size of current stage
  window_size = self.window_spec[cur_stage - 1]

@@ -841,7 +838,7 @@ class Hiera(nn.Module):
  x = self.patch_embed(x)
  # x: (B, H, W, C)

- # Add pos embed
+ # Add positional embedding
  x = x + self._get_pos_embed(x.shape[1:3])

  outputs = []
ultralytics/models/sam/modules/memory_attention.py

@@ -144,7 +144,19 @@ class MemoryAttentionLayer(nn.Module):
  query_pos: Optional[Tensor] = None,
  num_k_exclude_rope: int = 0,
  ) -> torch.Tensor:
- """Process input tensors through self-attention, cross-attention, and feedforward network layers."""
+ """
+ Process input tensors through self-attention, cross-attention, and feedforward network layers.
+
+ Args:
+ tgt (Tensor): Target tensor for self-attention with shape (N, L, D).
+ memory (Tensor): Memory tensor for cross-attention with shape (N, S, D).
+ pos (Optional[Tensor]): Positional encoding for memory tensor.
+ query_pos (Optional[Tensor]): Positional encoding for target tensor.
+ num_k_exclude_rope (int): Number of keys to exclude from rotary position embedding.
+
+ Returns:
+ (torch.Tensor): Processed tensor after attention and feedforward layers with shape (N, L, D).
+ """
  tgt = self._forward_sa(tgt, query_pos)
  tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
  # MLP
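
As a quick illustration of the typing pass in the hunks above, here is a minimal doctest-style sketch exercising the retyped DropPath module. The import path is inferred from the blocks.py hunks, and the tensor shape is an arbitrary example, not taken from the package.

>>> import torch
>>> from ultralytics.models.sam.modules.blocks import DropPath  # path inferred from the hunks above
>>> drop_path = DropPath(drop_prob=0.1, scale_by_keep=True)  # new typed signature: float and bool arguments
>>> x = torch.randn(4, 64, 256)  # arbitrary example shape
>>> output = drop_path(x)  # stochastic depth is applied only in training mode; output shape matches the input
>>> print(output.shape)
torch.Size([4, 64, 256])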