dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/encoders.py
@@ -1,6 +1,6 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

- from typing import List, Optional, Tuple, Type
+ from __future__ import annotations

  import torch
  import torch.nn as nn
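The hunk above drops the typing imports in favor of from __future__ import annotations, and the hunks that follow rewrite annotations in the built-in-generic style (Tuple → tuple, List → list, Type → type, Optional[X] → X | None). As a rough standalone sketch of that pattern, with hypothetical names rather than code from this package:

from __future__ import annotations  # PEP 563: annotations are kept as strings at runtime

import torch.nn as nn


def build_norms(
    channels: tuple[int, ...] = (256, 512),      # built-in generic (PEP 585) instead of typing.Tuple
    norm_layer: type[nn.Module] = nn.LayerNorm,  # type[...] instead of typing.Type
    extra: list[int] | None = None,              # X | None (PEP 604) instead of typing.Optional
) -> list[nn.Module]:
    """Return one normalization layer per channel count, using the modernized annotation style."""
    return [norm_layer(dim) for dim in list(channels) + (extra or [])]

Because the future import defers annotation evaluation, this syntax parses even on interpreters that predate native PEP 585/604 support in annotations, which is presumably what lets the remaining hunks drop the typing aliases wholesale.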
@@ -21,8 +21,7 @@ from .blocks import (


  class ImageEncoderViT(nn.Module):
- """
- An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
+ """An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.

  This class processes images by splitting them into patches, applying transformer blocks, and generating a final
  encoded representation through a neck module.
@@ -35,7 +34,7 @@ class ImageEncoderViT(nn.Module):
  neck (nn.Sequential): Neck module to further process the output.

  Methods:
- forward: Processes input through patch embedding, positional embedding, blocks, and neck.
+ forward: Process input through patch embedding, positional embedding, blocks, and neck.

  Examples:
  >>> import torch
@@ -56,16 +55,15 @@ class ImageEncoderViT(nn.Module):
  mlp_ratio: float = 4.0,
  out_chans: int = 256,
  qkv_bias: bool = True,
- norm_layer: Type[nn.Module] = nn.LayerNorm,
- act_layer: Type[nn.Module] = nn.GELU,
+ norm_layer: type[nn.Module] = nn.LayerNorm,
+ act_layer: type[nn.Module] = nn.GELU,
  use_abs_pos: bool = True,
  use_rel_pos: bool = False,
  rel_pos_zero_init: bool = True,
  window_size: int = 0,
- global_attn_indexes: Tuple[int, ...] = (),
+ global_attn_indexes: tuple[int, ...] = (),
  ) -> None:
- """
- Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
+ """Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.

  Args:
  img_size (int): Input image size, assumed to be square.
@@ -83,7 +81,7 @@ class ImageEncoderViT(nn.Module):
  use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
  rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
  window_size (int): Size of attention window for windowed attention blocks.
- global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
+ global_attn_indexes (tuple[int, ...]): Indices of blocks that use global attention.

  Examples:
  >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
@@ -101,9 +99,9 @@ class ImageEncoderViT(nn.Module):
  embed_dim=embed_dim,
  )

- self.pos_embed: Optional[nn.Parameter] = None
+ self.pos_embed: nn.Parameter | None = None
  if use_abs_pos:
- # Initialize absolute positional embedding with pretrain image size.
+ # Initialize absolute positional embedding with pretrain image size
  self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))

  self.blocks = nn.ModuleList()
@@ -156,24 +154,23 @@ class ImageEncoderViT(nn.Module):


  class PromptEncoder(nn.Module):
- """
- Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
+ """Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.

  Attributes:
  embed_dim (int): Dimension of the embeddings.
- input_image_size (Tuple[int, int]): Size of the input image as (H, W).
- image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
+ input_image_size (tuple[int, int]): Size of the input image as (H, W).
+ image_embedding_size (tuple[int, int]): Spatial size of the image embedding as (H, W).
  pe_layer (PositionEmbeddingRandom): Module for random position embedding.
  num_point_embeddings (int): Number of point embeddings for different types of points.
  point_embeddings (nn.ModuleList): List of point embeddings.
  not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
- mask_input_size (Tuple[int, int]): Size of the input mask.
+ mask_input_size (tuple[int, int]): Size of the input mask.
  mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
  no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.

  Methods:
- get_dense_pe: Returns the positional encoding used to encode point prompts.
- forward: Embeds different types of prompts, returning both sparse and dense embeddings.
+ get_dense_pe: Return the positional encoding used to encode point prompts.
+ forward: Embed different types of prompts, returning both sparse and dense embeddings.

  Examples:
  >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
@@ -188,18 +185,17 @@ class PromptEncoder(nn.Module):
  def __init__(
  self,
  embed_dim: int,
- image_embedding_size: Tuple[int, int],
- input_image_size: Tuple[int, int],
+ image_embedding_size: tuple[int, int],
+ input_image_size: tuple[int, int],
  mask_in_chans: int,
- activation: Type[nn.Module] = nn.GELU,
+ activation: type[nn.Module] = nn.GELU,
  ) -> None:
- """
- Initialize the PromptEncoder module for encoding various types of prompts.
+ """Initialize the PromptEncoder module for encoding various types of prompts.

  Args:
  embed_dim (int): The dimension of the embeddings.
- image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
- input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
+ image_embedding_size (tuple[int, int]): The spatial size of the image embedding as (H, W).
+ input_image_size (tuple[int, int]): The padded size of the input image as (H, W).
  mask_in_chans (int): The number of hidden channels used for encoding input masks.
  activation (Type[nn.Module]): The activation function to use when encoding input masks.

@@ -236,15 +232,14 @@ class PromptEncoder(nn.Module):
  self.no_mask_embed = nn.Embedding(1, embed_dim)

  def get_dense_pe(self) -> torch.Tensor:
- """
- Return the dense positional encoding used for encoding point prompts.
+ """Return the dense positional encoding used for encoding point prompts.

  Generate a positional encoding for a dense set of points matching the shape of the image
  encoding. The encoding is used to provide spatial information to the model when processing point prompts.

  Returns:
- (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
- height and width of the image embedding size, respectively.
+ (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the height and
+ width of the image embedding size, respectively.

  Examples:
  >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
@@ -258,8 +253,8 @@ class PromptEncoder(nn.Module):
  """Embed point prompts by applying positional encoding and label-specific embeddings."""
  points = points + 0.5 # Shift to center of pixel
  if pad:
- padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
- padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ padding_point = torch.zeros((points.shape[0], 1, 2), dtype=points.dtype, device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), dtype=labels.dtype, device=labels.device)
  points = torch.cat([points, padding_point], dim=1)
  labels = torch.cat([labels, padding_label], dim=1)
  point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
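The _embed_points change above, like the sparse_embeddings allocation further down (which inlines the dtype and device of point_embeddings[0].weight and removes the _get_device helper), follows one pattern: new tensors are created with the dtype and device of an existing tensor so that half-precision or GPU inputs are not mixed with default float32 CPU tensors. A small self-contained sketch of the idea, using a hypothetical pad_points function rather than package code:

import torch


def pad_points(points: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Append one dummy point (label -1) per batch item, matching the inputs' dtype and device."""
    pad_pt = torch.zeros((points.shape[0], 1, 2), dtype=points.dtype, device=points.device)
    pad_lb = -torch.ones((labels.shape[0], 1), dtype=labels.dtype, device=labels.device)
    return torch.cat([points, pad_pt], dim=1), torch.cat([labels, pad_lb], dim=1)


pts = torch.rand(2, 3, 2).half()           # e.g. half-precision prompt coordinates
lbl = torch.ones(2, 3, dtype=torch.int64)  # integer point labels
padded_pts, padded_lbl = pad_points(pts, lbl)
assert padded_pts.dtype == torch.float16 and padded_lbl.dtype == torch.int64

Without the explicit dtype, torch.zeros defaults to float32, so the concatenation would promote fp16 prompts back to float32, and a CPU-allocated padding tensor would raise a device mismatch against CUDA inputs.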
@@ -286,9 +281,9 @@ class PromptEncoder(nn.Module):

  @staticmethod
  def _get_batch_size(
- points: Optional[Tuple[torch.Tensor, torch.Tensor]],
- boxes: Optional[torch.Tensor],
- masks: Optional[torch.Tensor],
+ points: tuple[torch.Tensor, torch.Tensor] | None,
+ boxes: torch.Tensor | None,
+ masks: torch.Tensor | None,
  ) -> int:
  """Get the batch size of the output given the batch size of the input prompts."""
  if points is not None:
@@ -300,30 +295,23 @@ class PromptEncoder(nn.Module):
  else:
  return 1

- def _get_device(self) -> torch.device:
- """Return the device of the first point embedding's weight tensor."""
- return self.point_embeddings[0].weight.device
-
  def forward(
  self,
- points: Optional[Tuple[torch.Tensor, torch.Tensor]],
- boxes: Optional[torch.Tensor],
- masks: Optional[torch.Tensor],
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Embed different types of prompts, returning both sparse and dense embeddings.
+ points: tuple[torch.Tensor, torch.Tensor] | None,
+ boxes: torch.Tensor | None,
+ masks: torch.Tensor | None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Embed different types of prompts, returning both sparse and dense embeddings.

  Args:
- points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
- tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
- shape (B, N).
+ points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first tensor
+ contains coordinates of shape (B, N, 2), and the second tensor contains labels of shape (B, N).
  boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
  masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).

  Returns:
- (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
- - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
- - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
+ sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
+ dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).

  Examples:
  >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
@@ -335,7 +323,11 @@ class PromptEncoder(nn.Module):
  torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  """
  bs = self._get_batch_size(points, boxes, masks)
- sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ sparse_embeddings = torch.empty(
+ (bs, 0, self.embed_dim),
+ dtype=self.point_embeddings[0].weight.dtype,
+ device=self.point_embeddings[0].weight.device,
+ )
  if points is not None:
  coords, labels = points
  point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
@@ -355,11 +347,10 @@ class PromptEncoder(nn.Module):


  class MemoryEncoder(nn.Module):
- """
- Encode pixel features and masks into a memory representation for efficient image segmentation.
+ """Encode pixel features and masks into a memory representation for efficient image segmentation.

- This class processes pixel-level features and masks, fusing them to generate encoded memory representations
- suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
+ This class processes pixel-level features and masks, fusing them to generate encoded memory representations suitable
+ for downstream tasks in image segmentation models like SAM (Segment Anything Model).

  Attributes:
  mask_downsampler (MaskDownSampler): Module for downsampling input masks.
@@ -386,15 +377,14 @@ class MemoryEncoder(nn.Module):
  out_dim,
  in_dim=256, # in_dim of pix_feats
  ):
- """
- Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
+ """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.

  This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations
  suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).

  Args:
  out_dim (int): Output dimension of the encoded features.
- in_dim (int): Input dimension of the pixel features. Default is 256.
+ in_dim (int): Input dimension of the pixel features.

  Examples:
  >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
@@ -420,7 +410,7 @@ class MemoryEncoder(nn.Module):
  pix_feat: torch.Tensor,
  masks: torch.Tensor,
  skip_mask_sigmoid: bool = False,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ) -> dict:
  """Process pixel features and masks to generate encoded memory representations for segmentation."""
  if not skip_mask_sigmoid:
  masks = F.sigmoid(masks)
@@ -440,11 +430,10 @@ class MemoryEncoder(nn.Module):


  class ImageEncoder(nn.Module):
- """
- Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
+ """Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.

- This class combines a trunk network for feature extraction with a neck network for feature refinement
- and positional encoding generation. It can optionally discard the lowest resolution features.
+ This class combines a trunk network for feature extraction with a neck network for feature refinement and positional
+ encoding generation. It can optionally discard the lowest resolution features.

  Attributes:
  trunk (nn.Module): The trunk network for initial feature extraction.
@@ -470,11 +459,10 @@ class ImageEncoder(nn.Module):
  neck: nn.Module,
  scalp: int = 0,
  ):
- """
- Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
+ """Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.

- This encoder combines a trunk network for feature extraction with a neck network for feature refinement
- and positional encoding generation. It can optionally discard the lowest resolution features.
+ This encoder combines a trunk network for feature extraction with a neck network for feature refinement and
+ positional encoding generation. It can optionally discard the lowest resolution features.

  Args:
  trunk (nn.Module): The trunk network for initial feature extraction.
@@ -499,7 +487,7 @@ class ImageEncoder(nn.Module):
  )

  def forward(self, sample: torch.Tensor):
- """Encode input through patch embedding, positional embedding, transformer blocks, and neck module."""
+ """Encode input through trunk and neck networks, returning multiscale features and positional encodings."""
  features, pos = self.neck(self.trunk(sample))
  if self.scalp > 0:
  # Discard the lowest resolution features
@@ -514,19 +502,18 @@ class ImageEncoder(nn.Module):


  class FpnNeck(nn.Module):
- """
- A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
+ """A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.

- This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
- similar to ViT positional embedding interpolation.
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, similar to ViT
+ positional embedding interpolation.

  Attributes:
  position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
  convs (nn.ModuleList): List of convolutional layers for each backbone level.
- backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ backbone_channel_list (list[int]): List of channel dimensions from the backbone.
  fpn_interp_model (str): Interpolation mode for FPN feature resizing.
  fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
- fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
+ fpn_top_down_levels (list[int]): Levels to have top-down features in outputs.

  Methods:
  forward: Perform forward pass through the FPN neck.
@@ -543,29 +530,28 @@ class FpnNeck(nn.Module):
  def __init__(
  self,
  d_model: int,
- backbone_channel_list: List[int],
+ backbone_channel_list: list[int],
  kernel_size: int = 1,
  stride: int = 1,
  padding: int = 0,
  fpn_interp_model: str = "bilinear",
  fuse_type: str = "sum",
- fpn_top_down_levels: Optional[List[int]] = None,
+ fpn_top_down_levels: list[int] | None = None,
  ):
- """
- Initializes a modified Feature Pyramid Network (FPN) neck.
+ """Initialize a modified Feature Pyramid Network (FPN) neck.

- This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
- similar to ViT positional embedding interpolation.
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, similar to
+ ViT positional embedding interpolation.

  Args:
  d_model (int): Dimension of the model.
- backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ backbone_channel_list (list[int]): List of channel dimensions from the backbone.
  kernel_size (int): Kernel size for the convolutional layers.
  stride (int): Stride for the convolutional layers.
  padding (int): Padding for the convolutional layers.
  fpn_interp_model (str): Interpolation mode for FPN feature resizing.
  fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
- fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
+ fpn_top_down_levels (Optional[list[int]]): Levels to have top-down features in outputs.

  Examples:
  >>> backbone_channels = [64, 128, 256, 512]
@@ -594,30 +580,28 @@ class FpnNeck(nn.Module):
  assert fuse_type in {"sum", "avg"}
  self.fuse_type = fuse_type

- # levels to have top-down features in its outputs
+ # Levels to have top-down features in its outputs
  # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
  # have top-down propagation, while outputs of level 0 and level 1 have only
- # lateral features from the same backbone level.
+ # lateral features from the same backbone level
  if fpn_top_down_levels is None:
- # default is to have top-down features on all levels
+ # Default is to have top-down features on all levels
  fpn_top_down_levels = range(len(self.convs))
  self.fpn_top_down_levels = list(fpn_top_down_levels)

- def forward(self, xs: List[torch.Tensor]):
- """
- Performs forward pass through the Feature Pyramid Network (FPN) neck.
+ def forward(self, xs: list[torch.Tensor]):
+ """Perform forward pass through the Feature Pyramid Network (FPN) neck.

  This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
  and top-down feature fusion. It generates output feature maps and corresponding positional encodings.

  Args:
- xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
+ xs (list[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).

  Returns:
- (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
- - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
- (B, d_model, H, W).
- - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
+ out (list[torch.Tensor]): List of output feature maps after FPN processing, each with shape (B, d_model, H,
+ W).
+ pos (list[torch.Tensor]): List of positional encodings corresponding to each output feature map.

  Examples:
  >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
@@ -629,17 +613,17 @@ class FpnNeck(nn.Module):
  out = [None] * len(self.convs)
  pos = [None] * len(self.convs)
  assert len(xs) == len(self.convs)
- # fpn forward pass
+ # FPN forward pass
  # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
  prev_features = None
- # forward in top-down order (from low to high resolution)
+ # Forward in top-down order (from low to high resolution)
  n = len(self.convs) - 1
  for i in range(n, -1, -1):
  x = xs[i]
  lateral_features = self.convs[n - i](x)
  if i in self.fpn_top_down_levels and prev_features is not None:
  top_down_features = F.interpolate(
- prev_features.to(dtype=torch.float32),
+ prev_features.to(dtype=x.dtype),
  scale_factor=2.0,
  mode=self.fpn_interp_model,
  align_corners=(None if self.fpn_interp_model == "nearest" else False),
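In the hunk above, the top-down features are cast to x.dtype instead of being forced to float32 before interpolation, so the upsample-and-sum stays in whatever precision the lateral features already use (for example under autocast). A hedged standalone illustration of the idea, with a hypothetical fuse_top_down helper rather than the package's FpnNeck code:

import torch
import torch.nn.functional as F


def fuse_top_down(prev: torch.Tensor, lateral: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
    """Upsample the coarser map 2x and add it to the lateral map, following the lateral map's dtype."""
    top_down = F.interpolate(
        prev.to(dtype=lateral.dtype),  # match the lateral dtype instead of hard-casting to float32
        scale_factor=2.0,
        mode=mode,
        align_corners=None if mode == "nearest" else False,
    )
    return lateral + top_down


lateral = torch.randn(1, 8, 8, 8)            # finer level, float32
coarse = torch.randn(1, 8, 4, 4).double()    # coarser level arriving in a different dtype
print(fuse_top_down(coarse, lateral).dtype)  # torch.float32: the fused output follows the lateral map

Under mixed precision the same cast keeps the interpolation and the sum in float16 or bfloat16, which is presumably the motivation for removing the unconditional float32 conversion.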
@@ -658,26 +642,25 @@ class FpnNeck(nn.Module):


  class Hiera(nn.Module):
- """
- Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
+ """Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.

- This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
- efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
- with optional pooling and global attention mechanisms.
+ This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for efficient
+ multiscale feature extraction. It uses a series of transformer blocks organized into stages, with optional pooling
+ and global attention mechanisms.

  Attributes:
- window_spec (Tuple[int, ...]): Window sizes for each stage.
- q_stride (Tuple[int, int]): Downsampling stride between stages.
- stage_ends (List[int]): Indices of the last block in each stage.
- q_pool_blocks (List[int]): Indices of blocks where pooling is applied.
+ window_spec (tuple[int, ...]): Window sizes for each stage.
+ q_stride (tuple[int, int]): Downsampling stride between stages.
+ stage_ends (list[int]): Indices of the last block in each stage.
+ q_pool_blocks (list[int]): Indices of blocks where pooling is applied.
  return_interm_layers (bool): Whether to return intermediate layer outputs.
  patch_embed (PatchEmbed): Module for patch embedding.
- global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
- window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
+ global_att_blocks (tuple[int, ...]): Indices of blocks with global attention.
+ window_pos_embed_bkg_spatial_size (tuple[int, int]): Spatial size for window positional embedding background.
  pos_embed (nn.Parameter): Positional embedding for the background.
  pos_embed_window (nn.Parameter): Positional embedding for the window.
  blocks (nn.ModuleList): List of MultiScaleBlock modules.
- channel_list (List[int]): List of output channel dimensions for each stage.
+ channel_list (list[int]): List of output channel dimensions for each stage.

  Methods:
  _get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.
@@ -697,45 +680,45 @@ class Hiera(nn.Module):
  num_heads: int = 1, # initial number of heads
  drop_path_rate: float = 0.0, # stochastic depth
  q_pool: int = 3, # number of q_pool stages
- q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
- stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
+ q_stride: tuple[int, int] = (2, 2), # downsample stride bet. stages
+ stages: tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
  dim_mul: float = 2.0, # dim_mul factor at stage shift
  head_mul: float = 2.0, # head_mul factor at stage shift
- window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+ window_pos_embed_bkg_spatial_size: tuple[int, int] = (14, 14),
  # window size per stage, when not using global att.
- window_spec: Tuple[int, ...] = (
+ window_spec: tuple[int, ...] = (
  8,
  4,
  14,
  7,
  ),
  # global attn in these blocks
- global_att_blocks: Tuple[int, ...] = (
+ global_att_blocks: tuple[int, ...] = (
  12,
  16,
  20,
  ),
  return_interm_layers=True, # return feats from every stage
  ):
- """
- Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
+ """Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.

- Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction
- in image processing tasks. It uses a series of transformer blocks organized into stages, with optional
- pooling and global attention mechanisms.
+ Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction in
+ image processing tasks. It uses a series of transformer blocks organized into stages, with optional pooling and
+ global attention mechanisms.

  Args:
  embed_dim (int): Initial embedding dimension for the model.
  num_heads (int): Initial number of attention heads.
  drop_path_rate (float): Stochastic depth rate.
  q_pool (int): Number of query pooling stages.
- q_stride (Tuple[int, int]): Downsampling stride between stages.
- stages (Tuple[int, ...]): Number of blocks per stage.
+ q_stride (tuple[int, int]): Downsampling stride between stages.
+ stages (tuple[int, ...]): Number of blocks per stage.
  dim_mul (float): Dimension multiplier factor at stage transitions.
  head_mul (float): Head multiplier factor at stage transitions.
- window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
- window_spec (Tuple[int, ...]): Window sizes for each stage when not using global attention.
- global_att_blocks (Tuple[int, ...]): Indices of blocks that use global attention.
+ window_pos_embed_bkg_spatial_size (tuple[int, int]): Spatial size for window positional embedding
+ background.
+ window_spec (tuple[int, ...]): Window sizes for each stage when not using global attention.
+ global_att_blocks (tuple[int, ...]): Indices of blocks that use global attention.
  return_interm_layers (bool): Whether to return intermediate layer outputs.

  Examples:
@@ -763,7 +746,7 @@ class Hiera(nn.Module):
  stride=(4, 4),
  padding=(3, 3),
  )
- # Which blocks have global att?
+ # Which blocks have global attention?
  self.global_att_blocks = global_att_blocks

  # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
@@ -778,8 +761,7 @@ class Hiera(nn.Module):

  for i in range(depth):
  dim_out = embed_dim
- # lags by a block, so first block of
- # next stage uses an initial window size
+ # Lags by a block, so first block of next stage uses an initial window size
  # of previous stage and final window size of current stage
  window_size = self.window_spec[cur_stage - 1]

@@ -809,7 +791,7 @@ class Hiera(nn.Module):
  else [self.blocks[-1].dim_out]
  )

- def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+ def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
  """Generate positional embeddings by interpolating and combining window and background embeddings."""
  h, w = hw
  window_embed = self.pos_embed_window
@@ -818,15 +800,14 @@ class Hiera(nn.Module):
  pos_embed = pos_embed.permute(0, 2, 3, 1)
  return pos_embed

- def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
- """
- Perform forward pass through Hiera model, extracting multiscale features from input images.
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+ """Perform forward pass through Hiera model, extracting multiscale features from input images.

  Args:
  x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.

  Returns:
- (List[torch.Tensor]): List of feature maps at different scales, each with shape (B, C_i, H_i, W_i), where
+ (list[torch.Tensor]): List of feature maps at different scales, each with shape (B, C_i, H_i, W_i), where
  C_i is the channel dimension and H_i, W_i are the spatial dimensions at scale i. The list is ordered
  from highest resolution (fine features) to lowest resolution (coarse features) if return_interm_layers
  is True, otherwise contains only the final output.
@@ -841,7 +822,7 @@ class Hiera(nn.Module):
  x = self.patch_embed(x)
  # x: (B, H, W, C)

- # Add pos embed
+ # Add positional embedding
  x = x + self._get_pos_embed(x.shape[1:3])

  outputs = []