dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/encoders.py (path inferred from the class names and the file list above)
@@ -21,8 +21,7 @@ from .blocks import (
 
 
 class ImageEncoderViT(nn.Module):
-    """
-    An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
+    """An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
 
     This class processes images by splitting them into patches, applying transformer blocks, and generating a final
     encoded representation through a neck module.
@@ -64,8 +63,7 @@ class ImageEncoderViT(nn.Module):
         window_size: int = 0,
         global_attn_indexes: tuple[int, ...] = (),
     ) -> None:
-        """
-        Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
+        """Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
 
         Args:
             img_size (int): Input image size, assumed to be square.
@@ -84,12 +82,6 @@ class ImageEncoderViT(nn.Module):
             rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
             window_size (int): Size of attention window for windowed attention blocks.
             global_attn_indexes (tuple[int, ...]): Indices of blocks that use global attention.
-
-        Examples:
-            >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
-            >>> input_image = torch.randn(1, 3, 224, 224)
-            >>> output = encoder(input_image)
-            >>> print(output.shape)
         """
         super().__init__()
         self.img_size = img_size
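Note: the Examples block deleted above was this docstring's only usage reference. An equivalent standalone sketch, mirroring the removed doctest (import path assumed from the file list above):

```python
import torch

from ultralytics.models.sam.modules.encoders import ImageEncoderViT  # path assumed from the file list above

# Same call pattern as the removed docstring example: encode a batch of images.
encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
output = encoder(torch.randn(1, 3, 224, 224))
print(output.shape)
```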
@@ -156,8 +148,7 @@ class ImageEncoderViT(nn.Module):
 
 
 class PromptEncoder(nn.Module):
-    """
-    Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
+    """Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
 
     Attributes:
         embed_dim (int): Dimension of the embeddings.
@@ -193,8 +184,7 @@ class PromptEncoder(nn.Module):
         mask_in_chans: int,
         activation: type[nn.Module] = nn.GELU,
     ) -> None:
-        """
-        Initialize the PromptEncoder module for encoding various types of prompts.
+        """Initialize the PromptEncoder module for encoding various types of prompts.
 
         Args:
             embed_dim (int): The dimension of the embeddings.
@@ -202,15 +192,6 @@ class PromptEncoder(nn.Module):
             input_image_size (tuple[int, int]): The padded size of the input image as (H, W).
             mask_in_chans (int): The number of hidden channels used for encoding input masks.
             activation (Type[nn.Module]): The activation function to use when encoding input masks.
-
-        Examples:
-            >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
-            >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
-            >>> boxes = torch.rand(1, 2, 2)
-            >>> masks = torch.rand(1, 1, 256, 256)
-            >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
-            >>> print(sparse_embeddings.shape, dense_embeddings.shape)
-            torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
         """
         super().__init__()
         self.embed_dim = embed_dim
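Note: as above, the deleted doctest is reproduced here as a standalone sketch for reference (import path assumed; the printed shapes are those given by the removed example):

```python
import torch

from ultralytics.models.sam.modules.encoders import PromptEncoder  # path assumed from the file list above

# Same call pattern as the removed docstring example: embed point, box, and mask prompts.
prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))  # coords (B, N, 2), labels (B, N)
boxes = torch.rand(1, 2, 2)
masks = torch.rand(1, 1, 256, 256)
sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
print(sparse_embeddings.shape, dense_embeddings.shape)
# torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
```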
@@ -236,15 +217,14 @@ class PromptEncoder(nn.Module):
         self.no_mask_embed = nn.Embedding(1, embed_dim)
 
     def get_dense_pe(self) -> torch.Tensor:
-        """
-        Return the dense positional encoding used for encoding point prompts.
+        """Return the dense positional encoding used for encoding point prompts.
 
         Generate a positional encoding for a dense set of points matching the shape of the image
         encoding. The encoding is used to provide spatial information to the model when processing point prompts.
 
         Returns:
-            (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
-                height and width of the image embedding size, respectively.
+            (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the height and
+                width of the image embedding size, respectively.
 
         Examples:
             >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
@@ -306,13 +286,11 @@ class PromptEncoder(nn.Module):
         boxes: torch.Tensor | None,
         masks: torch.Tensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Embed different types of prompts, returning both sparse and dense embeddings.
+        """Embed different types of prompts, returning both sparse and dense embeddings.
 
         Args:
-            points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
-                tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
-                shape (B, N).
+            points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first tensor
+                contains coordinates of shape (B, N, 2), and the second tensor contains labels of shape (B, N).
             boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
             masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
 
@@ -354,11 +332,10 @@ class PromptEncoder(nn.Module):
 
 
 class MemoryEncoder(nn.Module):
-    """
-    Encode pixel features and masks into a memory representation for efficient image segmentation.
+    """Encode pixel features and masks into a memory representation for efficient image segmentation.
 
-    This class processes pixel-level features and masks, fusing them to generate encoded memory representations
-    suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
+    This class processes pixel-level features and masks, fusing them to generate encoded memory representations suitable
+    for downstream tasks in image segmentation models like SAM (Segment Anything Model).
 
     Attributes:
         mask_downsampler (MaskDownSampler): Module for downsampling input masks.
@@ -384,9 +361,9 @@ class MemoryEncoder(nn.Module):
         self,
         out_dim,
         in_dim=256,  # in_dim of pix_feats
+        interpol_size: tuple[int, int] | None = None,
     ):
-        """
-        Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
+        """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
 
         This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations
         suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
@@ -394,18 +371,12 @@ class MemoryEncoder(nn.Module):
         Args:
             out_dim (int): Output dimension of the encoded features.
             in_dim (int): Input dimension of the pixel features.
-
-        Examples:
-            >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
-            >>> pix_feat = torch.randn(1, 256, 64, 64)
-            >>> masks = torch.randn(1, 1, 64, 64)
-            >>> encoded_feat, pos = encoder(pix_feat, masks)
-            >>> print(encoded_feat.shape, pos.shape)
-            torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
+            interpol_size (tuple[int, int] | None): Size to interpolate masks to. If None, uses the size of pixel
+                features.
         """
         super().__init__()
 
-        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1, interpol_size=interpol_size)
 
         self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
         self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
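Note: the new interpol_size argument is simply forwarded to MaskDownSampler. A minimal construction sketch relying only on the documented semantics; the exact resize behavior lives in the updated MaskDownSampler:

```python
from ultralytics.models.sam.modules.encoders import MemoryEncoder  # path assumed from the file list above

# New in 8.4.x: interpol_size controls the size masks are interpolated to;
# None (the default) falls back to the pixel-feature size, matching 8.3.x behavior.
encoder = MemoryEncoder(out_dim=256, in_dim=256, interpol_size=(64, 64))  # illustrative value
```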
@@ -439,11 +410,10 @@ class MemoryEncoder(nn.Module):
 
 
 class ImageEncoder(nn.Module):
-    """
-    Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
+    """Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
 
-    This class combines a trunk network for feature extraction with a neck network for feature refinement
-    and positional encoding generation. It can optionally discard the lowest resolution features.
+    This class combines a trunk network for feature extraction with a neck network for feature refinement and positional
+    encoding generation. It can optionally discard the lowest resolution features.
 
     Attributes:
         trunk (nn.Module): The trunk network for initial feature extraction.
@@ -469,25 +439,15 @@ class ImageEncoder(nn.Module):
         neck: nn.Module,
         scalp: int = 0,
     ):
-        """
-        Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
+        """Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
 
-        This encoder combines a trunk network for feature extraction with a neck network for feature refinement
-        and positional encoding generation. It can optionally discard the lowest resolution features.
+        This encoder combines a trunk network for feature extraction with a neck network for feature refinement and
+        positional encoding generation. It can optionally discard the lowest resolution features.
 
         Args:
             trunk (nn.Module): The trunk network for initial feature extraction.
             neck (nn.Module): The neck network for feature refinement and positional encoding generation.
             scalp (int): Number of lowest resolution feature levels to discard.
-
-        Examples:
-            >>> trunk = SomeTrunkNetwork()
-            >>> neck = SomeNeckNetwork()
-            >>> encoder = ImageEncoder(trunk, neck, scalp=1)
-            >>> image = torch.randn(1, 3, 224, 224)
-            >>> output = encoder(image)
-            >>> print(output.keys())
-            dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
        """
         super().__init__()
         self.trunk = trunk
@@ -513,11 +473,10 @@ class ImageEncoder(nn.Module):
 
 
 class FpnNeck(nn.Module):
-    """
-    A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
+    """A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
 
-    This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
-    similar to ViT positional embedding interpolation.
+    This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, similar to ViT
+    positional embedding interpolation.
 
     Attributes:
         position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
@@ -550,11 +509,10 @@ class FpnNeck(nn.Module):
         fuse_type: str = "sum",
         fpn_top_down_levels: list[int] | None = None,
     ):
-        """
-        Initialize a modified Feature Pyramid Network (FPN) neck.
+        """Initialize a modified Feature Pyramid Network (FPN) neck.
 
-        This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
-        similar to ViT positional embedding interpolation.
+        This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, similar to
+        ViT positional embedding interpolation.
 
         Args:
             d_model (int): Dimension of the model.
@@ -565,11 +523,6 @@ class FpnNeck(nn.Module):
             fpn_interp_model (str): Interpolation mode for FPN feature resizing.
             fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
             fpn_top_down_levels (Optional[list[int]]): Levels to have top-down features in outputs.
-
-        Examples:
-            >>> backbone_channels = [64, 128, 256, 512]
-            >>> fpn_neck = FpnNeck(256, backbone_channels)
-            >>> print(fpn_neck)
         """
         super().__init__()
         self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
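Note: the deleted FpnNeck doctest covered construction only; an equivalent standalone sketch (import path assumed from the file list above):

```python
from ultralytics.models.sam.modules.encoders import FpnNeck  # path assumed from the file list above

# Same call pattern as the removed docstring example: d_model plus backbone channel widths.
backbone_channels = [64, 128, 256, 512]
fpn_neck = FpnNeck(256, backbone_channels)
print(fpn_neck)
```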
@@ -603,8 +556,7 @@ class FpnNeck(nn.Module):
         self.fpn_top_down_levels = list(fpn_top_down_levels)
 
     def forward(self, xs: list[torch.Tensor]):
-        """
-        Perform forward pass through the Feature Pyramid Network (FPN) neck.
+        """Perform forward pass through the Feature Pyramid Network (FPN) neck.
 
         This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
         and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
@@ -613,8 +565,8 @@ class FpnNeck(nn.Module):
             xs (list[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
 
         Returns:
-            out (list[torch.Tensor]): List of output feature maps after FPN processing, each with shape
-                (B, d_model, H, W).
+            out (list[torch.Tensor]): List of output feature maps after FPN processing, each with shape (B, d_model, H,
+                W).
             pos (list[torch.Tensor]): List of positional encodings corresponding to each output feature map.
 
         Examples:
@@ -656,12 +608,11 @@ class FpnNeck(nn.Module):
 
 
 class Hiera(nn.Module):
-    """
-    Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
+    """Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
 
-    This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
-    efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
-    with optional pooling and global attention mechanisms.
+    This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for efficient
+    multiscale feature extraction. It uses a series of transformer blocks organized into stages, with optional pooling
+    and global attention mechanisms.
 
     Attributes:
         window_spec (tuple[int, ...]): Window sizes for each stage.
@@ -715,12 +666,11 @@ class Hiera(nn.Module):
         ),
         return_interm_layers=True,  # return feats from every stage
     ):
-        """
-        Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
+        """Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
 
-        Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction
-        in image processing tasks. It uses a series of transformer blocks organized into stages, with optional
-        pooling and global attention mechanisms.
+        Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction in
+        image processing tasks. It uses a series of transformer blocks organized into stages, with optional pooling and
+        global attention mechanisms.
 
         Args:
             embed_dim (int): Initial embedding dimension for the model.
@@ -731,17 +681,11 @@ class Hiera(nn.Module):
             stages (tuple[int, ...]): Number of blocks per stage.
             dim_mul (float): Dimension multiplier factor at stage transitions.
             head_mul (float): Head multiplier factor at stage transitions.
-            window_pos_embed_bkg_spatial_size (tuple[int, int]): Spatial size for window positional embedding background.
+            window_pos_embed_bkg_spatial_size (tuple[int, int]): Spatial size for window positional embedding
+                background.
             window_spec (tuple[int, ...]): Window sizes for each stage when not using global attention.
             global_att_blocks (tuple[int, ...]): Indices of blocks that use global attention.
             return_interm_layers (bool): Whether to return intermediate layer outputs.
-
-        Examples:
-            >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
-            >>> input_tensor = torch.randn(1, 3, 224, 224)
-            >>> output_features = model(input_tensor)
-            >>> for feat in output_features:
-            ...     print(feat.shape)
         """
         super().__init__()
 
@@ -816,8 +760,7 @@ class Hiera(nn.Module):
         return pos_embed
 
     def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
-        """
-        Perform forward pass through Hiera model, extracting multiscale features from input images.
+        """Perform forward pass through Hiera model, extracting multiscale features from input images.
 
         Args:
             x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.
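Note: the Hiera class Examples were removed earlier in this file; an equivalent standalone sketch, mirroring the removed doctest (import path assumed from the file list above):

```python
import torch

from ultralytics.models.sam.modules.encoders import Hiera  # path assumed from the file list above

# Same call pattern as the removed docstring example: multiscale feature extraction.
model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
for feat in model(torch.randn(1, 3, 224, 224)):
    print(feat.shape)
```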
ultralytics/models/sam/modules/memory_attention.py (path inferred from the class names and the file list above)
@@ -11,8 +11,7 @@ from .blocks import RoPEAttention
 
 
 class MemoryAttentionLayer(nn.Module):
-    """
-    Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
+    """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
 
     This class combines self-attention, cross-attention, and feedforward components to process input tensors and
     generate memory-based attention outputs.
@@ -60,9 +59,10 @@ class MemoryAttentionLayer(nn.Module):
         pos_enc_at_attn: bool = False,
         pos_enc_at_cross_attn_keys: bool = True,
         pos_enc_at_cross_attn_queries: bool = False,
+        self_attn: nn.Module | None = None,
+        cross_attn: nn.Module | None = None,
     ):
-        """
-        Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
+        """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
 
         Args:
             d_model (int): Dimensionality of the model.
@@ -71,13 +71,15 @@ class MemoryAttentionLayer(nn.Module):
             pos_enc_at_attn (bool): Whether to add positional encoding at attention.
             pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
             pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
+            self_attn (nn.Module | None): Custom self-attention module. If None, a default RoPEAttention is used.
+            cross_attn (nn.Module | None): Custom cross-attention module. If None, a default RoPEAttention is used.
         """
         super().__init__()
         self.d_model = d_model
         self.dim_feedforward = dim_feedforward
         self.dropout_value = dropout
-        self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
-        self.cross_attn_image = RoPEAttention(
+        self.self_attn = self_attn or RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+        self.cross_attn_image = cross_attn or RoPEAttention(
             rope_k_repeat=True,
             embedding_dim=256,
             num_heads=1,
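Note: this change makes both attention modules injectable; leaving the new arguments at None keeps the previous hard-coded RoPEAttention behavior. A hedged sketch using only the constructor arguments visible in this hunk (the in-code defaults may pass additional kwargs):

```python
from ultralytics.models.sam.modules.blocks import RoPEAttention  # import per the hunk header above
from ultralytics.models.sam.modules.memory_attention import MemoryAttentionLayer

# Passing the modules explicitly mirrors the former hard-coded defaults; any
# compatible attention nn.Module could be substituted for either argument.
layer = MemoryAttentionLayer(
    d_model=256,
    self_attn=RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1),
    cross_attn=RoPEAttention(rope_k_repeat=True, embedding_dim=256, num_heads=1),
)
```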
@@ -145,8 +147,7 @@ class MemoryAttentionLayer(nn.Module):
         query_pos: torch.Tensor | None = None,
         num_k_exclude_rope: int = 0,
     ) -> torch.Tensor:
-        """
-        Process input tensors through self-attention, cross-attention, and feedforward network layers.
+        """Process input tensors through self-attention, cross-attention, and feedforward network layers.
 
         Args:
             tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).
@@ -168,11 +169,10 @@ class MemoryAttentionLayer(nn.Module):
 
 
 class MemoryAttention(nn.Module):
-    """
-    Memory attention module for processing sequential data with self and cross-attention mechanisms.
+    """Memory attention module for processing sequential data with self and cross-attention mechanisms.
 
-    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
-    for processing sequential data, particularly useful in transformer-like architectures.
+    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
+    processing sequential data, particularly useful in transformer-like architectures.
 
     Attributes:
         d_model (int): The dimension of the model's hidden state.
@@ -206,11 +206,10 @@ class MemoryAttention(nn.Module):
         num_layers: int,
         batch_first: bool = True,  # Do layers expect batch first input?
     ):
-        """
-        Initialize MemoryAttention with specified layers and normalization for sequential data processing.
+        """Initialize MemoryAttention with specified layers and normalization for sequential data processing.
 
-        This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
-        for processing sequential data, particularly useful in transformer-like architectures.
+        This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
+        processing sequential data, particularly useful in transformer-like architectures.
 
         Args:
             d_model (int): The dimension of the model's hidden state.
@@ -218,18 +217,6 @@ class MemoryAttention(nn.Module):
             layer (nn.Module): The attention layer to be used in the module.
             num_layers (int): The number of attention layers.
             batch_first (bool): Whether the input tensors are in batch-first format.
-
-        Examples:
-            >>> d_model = 256
-            >>> layer = MemoryAttentionLayer(d_model)
-            >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
-            >>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
-            >>> memory = torch.randn(20, 32, d_model)  # (mem_len, batch_size, d_model)
-            >>> curr_pos = torch.randn(10, 32, d_model)
-            >>> memory_pos = torch.randn(20, 32, d_model)
-            >>> output = attention(curr, memory, curr_pos, memory_pos)
-            >>> print(output.shape)
-            torch.Size([10, 32, 256])
         """
         super().__init__()
         self.d_model = d_model
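Note: the deleted MemoryAttention doctest, reproduced as a standalone sketch for reference (import path assumed from the file list above):

```python
import torch

from ultralytics.models.sam.modules.memory_attention import MemoryAttention, MemoryAttentionLayer

# Same call pattern as the removed docstring example: three stacked memory attention layers.
d_model = 256
layer = MemoryAttentionLayer(d_model)
attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
memory = torch.randn(20, 32, d_model)  # (mem_len, batch_size, d_model)
curr_pos = torch.randn(10, 32, d_model)
memory_pos = torch.randn(20, 32, d_model)
output = attention(curr, memory, curr_pos, memory_pos)
print(output.shape)  # torch.Size([10, 32, 256])
```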
@@ -247,8 +234,7 @@ class MemoryAttention(nn.Module):
         memory_pos: torch.Tensor | None = None,  # pos_enc for cross-attention inputs
         num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
     ) -> torch.Tensor:
-        """
-        Process inputs through attention layers, applying self and cross-attention with positional encoding.
+        """Process inputs through attention layers, applying self and cross-attention with positional encoding.
 
         Args:
             curr (torch.Tensor): Self-attention input tensor, representing the current state.