ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
@@ -176,7 +176,7 @@ class SAM2Model(torch.nn.Module):
  compile_image_encoder: bool = False,
  ):
  """
- Initializes the SAM2Model for video object segmentation with memory-based tracking.
+ Initialize the SAM2Model for video object segmentation with memory-based tracking.

  Args:
  image_encoder (nn.Module): Visual encoder for extracting image features.
@@ -213,9 +213,9 @@ class SAM2Model(torch.nn.Module):
  the encoder.
  proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
  encoding in object pointers.
- use_signed_tpos_enc_to_obj_ptrs (bool): whether to use signed distance (instead of unsigned absolute distance)
- in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True`
- and `add_tpos_enc_to_obj_ptrs=True`.
+ use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute distance)
+ in the temporal positional encoding in the object pointers, only relevant when both
+ `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`.
  only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
  during evaluation.
  pred_obj_scores (bool): Whether to predict if there is an object in the frame.
@@ -332,18 +332,18 @@ class SAM2Model(torch.nn.Module):

  @property
  def device(self):
- """Returns the device on which the model's parameters are stored."""
+ """Return the device on which the model's parameters are stored."""
  return next(self.parameters()).device

  def forward(self, *args, **kwargs):
- """Processes image and prompt inputs to generate object masks and scores in video sequences."""
+ """Process image and prompt inputs to generate object masks and scores in video sequences."""
  raise NotImplementedError(
  "Please use the corresponding methods in SAM2VideoPredictor for inference."
  "See notebooks/video_predictor_example.ipynb for an example."
  )

  def _build_sam_heads(self):
- """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
+ """Build SAM-style prompt encoder and mask decoder for image segmentation tasks."""
  self.sam_prompt_embed_dim = self.hidden_dim
  self.sam_image_embedding_size = self.image_size // self.backbone_stride

@@ -545,7 +545,7 @@ class SAM2Model(torch.nn.Module):
  )

  def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
- """Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
+ """Process mask inputs directly as output, bypassing SAM encoder/decoder."""
  # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
  out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
  mask_inputs_float = mask_inputs.float()
@@ -592,7 +592,7 @@ class SAM2Model(torch.nn.Module):
  )

  def forward_image(self, img_batch: torch.Tensor):
- """Processes image batch through encoder to extract multi-level features for SAM model."""
+ """Process image batch through encoder to extract multi-level features for SAM model."""
  backbone_out = self.image_encoder(img_batch)
  if self.use_high_res_features_in_sam:
  # precompute projected level 0 and level 1 features in SAM decoder
@@ -602,7 +602,7 @@ class SAM2Model(torch.nn.Module):
  return backbone_out

  def _prepare_backbone_features(self, backbone_out):
- """Prepares and flattens visual features from the image backbone output for further processing."""
+ """Prepare and flatten visual features from the image backbone output for further processing."""
  assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
  assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

@@ -627,7 +627,7 @@ class SAM2Model(torch.nn.Module):
  num_frames,
  track_in_reverse=False, # tracking in reverse time order (for demo usage)
  ):
- """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
+ """Prepare memory-conditioned features by fusing current frame's visual features with previous memories."""
  B = current_vision_feats[-1].size(1) # batch size on this frame
  C = self.hidden_dim
  H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
@@ -788,7 +788,7 @@ class SAM2Model(torch.nn.Module):
  object_score_logits,
  is_mask_from_pts,
  ):
- """Encodes frame features and masks into a new memory representation for video segmentation."""
+ """Encode frame features and masks into a new memory representation for video segmentation."""
  B = current_vision_feats[-1].size(1) # batch size on this frame
  C = self.hidden_dim
  H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
@@ -838,7 +838,7 @@ class SAM2Model(torch.nn.Module):
  track_in_reverse,
  prev_sam_mask_logits,
  ):
- """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+ """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
  current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
  # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
  if len(current_vision_feats) > 1:
@@ -893,9 +893,7 @@ class SAM2Model(torch.nn.Module):
  object_score_logits,
  current_out,
  ):
- """Finally run the memory encoder on the predicted mask to encode, it into a new memory feature (that can be
- used in future frames).
- """
+ """Run memory encoder on predicted mask to encode it into a new memory feature for future frames."""
  if run_mem_encoder and self.num_maskmem > 0:
  high_res_masks_for_mem_enc = high_res_masks
  maskmem_features, maskmem_pos_enc = self._encode_new_memory(
@@ -932,7 +930,7 @@ class SAM2Model(torch.nn.Module):
  # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
  prev_sam_mask_logits=None,
  ):
- """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+ """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
  current_out, sam_outputs, _, _ = self._track_step(
  frame_idx,
  is_init_cond_frame,
@@ -970,7 +968,7 @@ class SAM2Model(torch.nn.Module):
  return current_out

  def _use_multimask(self, is_init_cond_frame, point_inputs):
- """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
+ """Determine whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
  num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
  return (
  self.multimask_output_in_sam
@@ -980,7 +978,7 @@ class SAM2Model(torch.nn.Module):

  @staticmethod
  def _apply_non_overlapping_constraints(pred_masks):
- """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
+ """Apply non-overlapping constraints to masks, keeping the highest scoring object per location."""
  batch_size = pred_masks.size(0)
  if batch_size == 1:
  return pred_masks
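The hunk above only touches the docstring, but the behavior it names (keep, at each spatial location, only the highest-scoring object's mask logits) can be sketched in a few lines of plain PyTorch. This is an illustrative sketch with assumed tensor shapes, not the package's implementation:

import torch

# Toy stand-in for per-object mask logits: (num_objects, 1, H, W), assumed shape.
pred_masks = torch.randn(3, 1, 4, 4)

# Keep only the highest-scoring object at each pixel; every other object's logit
# at that pixel is clamped to a very negative value so it loses after sigmoid.
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)          # winning object per pixel
batch_obj_inds = torch.arange(pred_masks.size(0)).view(-1, 1, 1, 1)   # each object's own index
keep = max_obj_inds == batch_obj_inds                                  # True where this object wins
constrained = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
print(constrained.shape)  # torch.Size([3, 1, 4, 4])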
@@ -1001,12 +999,7 @@ class SAM2Model(torch.nn.Module):
  self.binarize_mask_from_pts_for_mem_enc = binarize

  def set_imgsz(self, imgsz):
- """
- Set image size to make model compatible with different image sizes.
-
- Args:
- imgsz (Tuple[int, int]): The size of the input image.
- """
+ """Set image size to make model compatible with different image sizes."""
  self.image_size = imgsz[0]
  self.sam_prompt_encoder.input_image_size = imgsz
  self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16
@@ -27,7 +27,7 @@ class Conv2d_BN(torch.nn.Sequential):

  Attributes:
  c (torch.nn.Conv2d): 2D convolution layer.
- 1 (torch.nn.BatchNorm2d): Batch normalization layer.
+ bn (torch.nn.BatchNorm2d): Batch normalization layer.

  Methods:
  __init__: Initializes the Conv2d_BN with specified parameters.
@@ -265,9 +265,9 @@ class ConvLayer(nn.Module):
  dim (int): The dimensionality of the input and output.
  input_resolution (Tuple[int, int]): The resolution of the input image.
  depth (int): The number of MBConv layers in the block.
- activation (Callable): Activation function applied after each convolution.
+ activation (nn.Module): Activation function applied after each convolution.
  drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
- downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
+ downsample (Optional[nn.Module]): Function for downsampling the output. None to skip downsampling.
  use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
  out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
  conv_expand_ratio (float): Expansion ratio for the MBConv layers.
@@ -413,12 +413,9 @@ class Attention(torch.nn.Module):
  Args:
  dim (int): The dimensionality of the input and output.
  key_dim (int): The dimensionality of the keys and queries.
- num_heads (int): Number of attention heads. Default is 8.
- attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
- resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
-
- Raises:
- AssertionError: If 'resolution' is not a tuple of length 2.
+ num_heads (int): Number of attention heads.
+ attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors.
+ resolution (Tuple[int, int]): Spatial resolution of the input feature map.

  Examples:
  >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
@@ -821,22 +818,20 @@ class TinyViT(nn.Module):
  attention and convolution blocks, and a classification head.

  Args:
- img_size (int): Size of the input image. Default is 224.
- in_chans (int): Number of input channels. Default is 3.
- num_classes (int): Number of classes for classification. Default is 1000.
+ img_size (int): Size of the input image.
+ in_chans (int): Number of input channels.
+ num_classes (int): Number of classes for classification.
  embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
- Default is (96, 192, 384, 768).
- depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
+ depths (Tuple[int, int, int, int]): Number of blocks in each stage.
  num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
- Default is (3, 6, 12, 24).
- window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
- mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
- drop_rate (float): Dropout rate. Default is 0.0.
- drop_path_rate (float): Stochastic depth rate. Default is 0.1.
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
- mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
- local_conv_size (int): Kernel size for local convolutions. Default is 3.
- layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
+ window_sizes (Tuple[int, int, int, int]): Window sizes for each stage.
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+ drop_rate (float): Dropout rate.
+ drop_path_rate (float): Stochastic depth rate.
+ use_checkpoint (bool): Whether to use checkpointing to save memory.
+ mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
+ local_conv_size (int): Kernel size for local convolutions.
+ layer_lr_decay (float): Layer-wise learning rate decay factor.

  Examples:
  >>> model = TinyViT(img_size=224, num_classes=1000)
@@ -992,12 +987,7 @@ class TinyViT(nn.Module):
  return self.forward_features(x)

  def set_imgsz(self, imgsz=[1024, 1024]):
- """
- Set image size to make model compatible with different image sizes.
-
- Args:
- imgsz (Tuple[int, int]): The size of the input image.
- """
+ """Set image size to make model compatible with different image sizes."""
  imgsz = [s // 4 for s in imgsz]
  self.patches_resolution = imgsz
  for i, layer in enumerate(self.layers):
@@ -57,23 +57,6 @@ class TwoWayTransformer(nn.Module):
  mlp_dim (int): Internal channel dimension for the MLP block.
  activation (Type[nn.Module]): Activation function to use in the MLP block.
  attention_downsample_rate (int): Downsampling rate for attention mechanism.
-
- Attributes:
- depth (int): Number of layers in the transformer.
- embedding_dim (int): Channel dimension for input embeddings.
- num_heads (int): Number of heads for multihead attention.
- mlp_dim (int): Internal channel dimension for the MLP block.
- layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
- final_attn_token_to_image (Attention): Final attention layer from queries to image.
- norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
-
- Examples:
- >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
- >>> image_embedding = torch.randn(1, 256, 32, 32)
- >>> image_pe = torch.randn(1, 256, 32, 32)
- >>> point_embedding = torch.randn(1, 100, 256)
- >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
- >>> print(output_queries.shape, output_image.shape)
  """
  super().__init__()
  self.depth = depth
@@ -104,23 +87,16 @@ class TwoWayTransformer(nn.Module):
  point_embedding: Tensor,
  ) -> Tuple[Tensor, Tensor]:
  """
- Processes image and point embeddings through the Two-Way Transformer.
+ Process image and point embeddings through the Two-Way Transformer.

  Args:
- image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
- image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
- point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
+ image_embedding (Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+ image_pe (Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+ point_embedding (Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).

  Returns:
- (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
-
- Examples:
- >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
- >>> image_embedding = torch.randn(1, 256, 32, 32)
- >>> image_pe = torch.randn(1, 256, 32, 32)
- >>> point_embedding = torch.randn(1, 100, 256)
- >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
- >>> print(output_queries.shape, output_image.shape)
+ queries (Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).
+ keys (Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).
  """
  # BxCxHxW -> BxHWxC == B x N_image_tokens x C
  image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
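The context line above documents the usual BxCxHxW to BxHWxC image-token flattening; a tiny illustration with assumed sizes:

import torch

image_embedding = torch.randn(1, 256, 32, 32)          # B x C x H x W
tokens = image_embedding.flatten(2).permute(0, 2, 1)    # B x (H*W) x C, i.e. one token per pixel
print(tokens.shape)  # torch.Size([1, 1024, 256])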
@@ -191,7 +167,7 @@ class TwoWayAttentionBlock(nn.Module):
  skip_first_layer_pe: bool = False,
  ) -> None:
  """
- Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
+ Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.

  This block implements a specialized transformer layer with four main components: self-attention on sparse
  inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
@@ -204,15 +180,6 @@ class TwoWayAttentionBlock(nn.Module):
  activation (Type[nn.Module]): Activation function for the MLP block.
  attention_downsample_rate (int): Downsampling rate for the attention mechanism.
  skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
-
- Examples:
- >>> embedding_dim, num_heads = 256, 8
- >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
- >>> queries = torch.randn(1, 100, embedding_dim)
- >>> keys = torch.randn(1, 1000, embedding_dim)
- >>> query_pe = torch.randn(1, 100, embedding_dim)
- >>> key_pe = torch.randn(1, 1000, embedding_dim)
- >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
  """
  super().__init__()
  self.self_attn = Attention(embedding_dim, num_heads)
@@ -230,7 +197,19 @@ class TwoWayAttentionBlock(nn.Module):
  self.skip_first_layer_pe = skip_first_layer_pe

  def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
- """Applies two-way attention to process query and key embeddings in a transformer block."""
+ """
+ Apply two-way attention to process query and key embeddings in a transformer block.
+
+ Args:
+ queries (Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
+ keys (Tensor): Key embeddings with shape (B, N_keys, embedding_dim).
+ query_pe (Tensor): Positional encodings for queries with same shape as queries.
+ key_pe (Tensor): Positional encodings for keys with same shape as keys.
+
+ Returns:
+ queries (Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).
+ keys (Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).
+ """
  # Self attention block
  if self.skip_first_layer_pe:
  queries = self.self_attn(q=queries, k=queries, v=queries)
@@ -301,27 +280,16 @@ class Attention(nn.Module):
  kv_in_dim: int = None,
  ) -> None:
  """
- Initializes the Attention module with specified dimensions and settings.
-
- This class implements a multi-head attention mechanism with optional downsampling of the internal
- dimension for queries, keys, and values.
+ Initialize the Attention module with specified dimensions and settings.

  Args:
  embedding_dim (int): Dimensionality of input embeddings.
  num_heads (int): Number of attention heads.
- downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
+ downsample_rate (int): Factor by which internal dimensions are downsampled.
  kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.

  Raises:
  AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
-
- Examples:
- >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
- >>> q = torch.randn(1, 100, 256)
- >>> k = v = torch.randn(1, 50, 256)
- >>> output = attn(q, k, v)
- >>> print(output.shape)
- torch.Size([1, 100, 256])
  """
  super().__init__()
  self.embedding_dim = embedding_dim
@@ -337,20 +305,30 @@ class Attention(nn.Module):

  @staticmethod
  def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
- """Separates the input tensor into the specified number of attention heads."""
+ """Separate the input tensor into the specified number of attention heads."""
  b, n, c = x.shape
  x = x.reshape(b, n, num_heads, c // num_heads)
  return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head

  @staticmethod
  def _recombine_heads(x: Tensor) -> Tensor:
- """Recombines separated attention heads into a single tensor."""
+ """Recombine separated attention heads into a single tensor."""
  b, n_heads, n_tokens, c_per_head = x.shape
  x = x.transpose(1, 2)
  return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C

  def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
- """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
+ """
+ Apply multi-head attention to query, key, and value tensors with optional downsampling.
+
+ Args:
+ q (Tensor): Query tensor with shape (B, N_q, embedding_dim).
+ k (Tensor): Key tensor with shape (B, N_k, embedding_dim).
+ v (Tensor): Value tensor with shape (B, N_k, embedding_dim).
+
+ Returns:
+ (Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).
+ """
  # Input projections
  q = self.q_proj(q)
  k = self.k_proj(k)
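For readers skimming the expanded docstrings, the split-heads, attend, recombine-heads pattern these helpers describe can be reproduced standalone. A minimal sketch in plain PyTorch with illustrative sizes, not the package's Attention class:

import torch

B, N_q, N_k, embed_dim, num_heads = 1, 100, 50, 256, 8
q = torch.randn(B, N_q, embed_dim)
k = torch.randn(B, N_k, embed_dim)
v = torch.randn(B, N_k, embed_dim)

def separate_heads(x, num_heads):
    b, n, c = x.shape
    return x.reshape(b, n, num_heads, c // num_heads).transpose(1, 2)  # B x heads x N x C_per_head

def recombine_heads(x):
    b, h, n, c = x.shape
    return x.transpose(1, 2).reshape(b, n, h * c)  # B x N x C

qh, kh, vh = (separate_heads(t, num_heads) for t in (q, k, v))
attn = torch.softmax(qh @ kh.transpose(-2, -1) / (qh.shape[-1] ** 0.5), dim=-1)
out = recombine_heads(attn @ vh)
print(out.shape)  # torch.Size([1, 100, 256])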
@@ -8,7 +8,7 @@ import torch.nn.functional as F

  def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
  """
- Selects the closest conditioning frames to a given frame index.
+ Select the closest conditioning frames to a given frame index.

  Args:
  frame_idx (int): Current frame index.
@@ -37,17 +37,17 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
  assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
  selected_outputs = {}

- # the closest conditioning frame before `frame_idx` (if any)
+ # The closest conditioning frame before `frame_idx` (if any)
  idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
  if idx_before is not None:
  selected_outputs[idx_before] = cond_frame_outputs[idx_before]

- # the closest conditioning frame after `frame_idx` (if any)
+ # The closest conditioning frame after `frame_idx` (if any)
  idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
  if idx_after is not None:
  selected_outputs[idx_after] = cond_frame_outputs[idx_after]

- # add other temporally closest conditioning frames until reaching a total
+ # Add other temporally closest conditioning frames until reaching a total
  # of `max_cond_frame_num` conditioning frames.
  num_remain = max_cond_frame_num - len(selected_outputs)
  inds_remain = sorted(
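A quick toy walkthrough of the selection logic shown in this hunk, using a hypothetical dictionary of per-frame outputs keyed by frame index:

# Hypothetical conditioning-frame outputs keyed by frame index.
cond_frame_outputs = {2: "out@2", 10: "out@10", 40: "out@40", 90: "out@90"}

# Nearest frame before 35 (10) and at/after 35 (40) are taken first,
# then the remaining temporally closest frames fill up to the budget.
frame_idx, max_cond_frame_num = 35, 3
before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)   # 10
after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)   # 40
selected = {t: cond_frame_outputs[t] for t in (before, after) if t is not None}
remaining = sorted((t for t in cond_frame_outputs if t not in selected),
                   key=lambda t: abs(t - frame_idx))[: max_cond_frame_num - len(selected)]
selected.update({t: cond_frame_outputs[t] for t in remaining})
print(sorted(selected))  # [2, 10, 40]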
@@ -61,7 +61,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num


  def get_1d_sine_pe(pos_inds, dim, temperature=10000):
- """Generates 1D sinusoidal positional embeddings for given positions and dimensions."""
+ """Generate 1D sinusoidal positional embeddings for given positions and dimensions."""
  pe_dim = dim // 2
  dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
  dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
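The hunk shows only the opening lines of get_1d_sine_pe; the usual way such an embedding is completed is to divide positions by the frequency terms and concatenate sine and cosine halves. A standalone sketch under that assumption, not the package's exact code:

import torch

def sine_pe_1d(pos_inds: torch.Tensor, dim: int, temperature: float = 10000.0) -> torch.Tensor:
    """Toy 1D sinusoidal positional embedding: half sine terms, half cosine terms."""
    pe_dim = dim // 2
    dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
    pos = pos_inds.unsqueeze(-1) / dim_t               # (..., pe_dim)
    return torch.cat([pos.sin(), pos.cos()], dim=-1)   # (..., dim)

print(sine_pe_1d(torch.arange(4, dtype=torch.float32), dim=8).shape)  # torch.Size([4, 8])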
@@ -72,7 +72,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):


  def init_t_xy(end_x: int, end_y: int):
- """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
+ """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions."""
  t = torch.arange(end_x * end_y, dtype=torch.float32)
  t_x = (t % end_x).float()
  t_y = torch.div(t, end_x, rounding_mode="floor").float()
@@ -80,7 +80,7 @@ def init_t_xy(end_x: int, end_y: int):


  def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
- """Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
+ """Compute axial complex exponential positional encodings for 2D spatial positions in a grid."""
  freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
  freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

@@ -93,7 +93,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):


  def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
- """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
+ """Reshape frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
  ndim = x.ndim
  assert 0 <= 1 < ndim
  assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
@@ -107,15 +107,15 @@ def apply_rotary_enc(
  freqs_cis: torch.Tensor,
  repeat_freqs_k: bool = False,
  ):
- """Applies rotary positional encoding to query and key tensors using complex-valued frequency components."""
+ """Apply rotary positional encoding to query and key tensors using complex-valued frequency components."""
  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
  xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
  freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
  xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
  if xk_ is None:
- # no keys to rotate, due to dropout
+ # No keys to rotate, due to dropout
  return xq_out.type_as(xq).to(xq.device), xk
- # repeat freqs along seq_len dim to match k seq_len
+ # Repeat freqs along seq_len dim to match k seq_len
  if repeat_freqs_k:
  r = xk_.shape[-2] // xq_.shape[-2]
  freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
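To make the complex-valued rotation that apply_rotary_enc documents concrete, here is a minimal rotary-encoding sketch on a toy query tensor; the shapes and frequency layout are assumptions for illustration, not the package's exact ones:

import torch

B, H, N, C = 1, 2, 8, 16            # batch, heads, tokens, channels (C must be even)
xq = torch.randn(B, H, N, C)

# One rotation angle per channel pair, scaled by token position.
theta = 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, C, 2).float() / C))   # (C/2,)
angles = torch.arange(N).float()[:, None] * freqs[None, :]     # (N, C/2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)       # complex exponentials e^{i*angle}

# View channel pairs as complex numbers, rotate, and flatten back to real pairs.
xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # (B, H, N, C/2)
xq_rot = torch.view_as_real(xq_c * freqs_cis).flatten(3)                 # (B, H, N, C)
print(xq_rot.shape)  # torch.Size([1, 2, 8, 16])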
@@ -125,7 +125,7 @@ def apply_rotary_enc(

  def window_partition(x, window_size):
  """
- Partitions input tensor into non-overlapping windows with padding if needed.
+ Partition input tensor into non-overlapping windows with padding if needed.

  Args:
  x (torch.Tensor): Input tensor with shape (B, H, W, C).
@@ -157,7 +157,7 @@ def window_partition(x, window_size):

  def window_unpartition(windows, window_size, pad_hw, hw):
  """
- Unpartitions windowed sequences into original sequences and removes padding.
+ Unpartition windowed sequences into original sequences and remove padding.

  This function reverses the windowing process, reconstructing the original input from windowed segments
  and removing any padding that was added during the windowing process.
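A small roundtrip sketch of the window partition and unpartition idea documented in these two hunks, written with plain tensor ops and illustrative shapes rather than the package functions:

import torch
import torch.nn.functional as F

B, H, W, C, ws = 1, 10, 12, 4, 7   # input is (B, H, W, C); window size 7 requires padding
x = torch.randn(B, H, W, C)

# Pad H and W up to multiples of the window size, then cut into (ws, ws) windows.
pad_h, pad_w = (ws - H % ws) % ws, (ws - W % ws) % ws
xp = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))                   # pad W then H (last dims first)
Hp, Wp = H + pad_h, W + pad_w
windows = (xp.view(B, Hp // ws, ws, Wp // ws, ws, C)
             .permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C))
print(windows.shape)  # torch.Size([4, 7, 7, 4])

# Reverse: stitch windows back together and drop the padding to recover the original tensor.
xr = (windows.view(B, Hp // ws, Wp // ws, ws, ws, C)
             .permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, C)[:, :H, :W, :])
print(torch.equal(xr, x))  # True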
@@ -195,7 +195,7 @@ def window_unpartition(windows, window_size, pad_hw, hw):

  def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  """
- Extracts relative positional embeddings based on query and key sizes.
+ Extract relative positional embeddings based on query and key sizes.

  Args:
  q_size (int): Size of the query.
@@ -244,7 +244,7 @@ def add_decomposed_rel_pos(
  k_size: Tuple[int, int],
  ) -> torch.Tensor:
  """
- Adds decomposed Relative Positional Embeddings to the attention map.
+ Add decomposed Relative Positional Embeddings to the attention map.

  This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
  paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
@@ -701,9 +701,6 @@ class SAM2Predictor(Predictor):
  - The method supports batched inference for multiple objects when points or bboxes are provided.
  - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
  - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
-
- References:
- - SAM2 Paper: [Add link to SAM2 paper when available]
  """
  features = self.get_im_features(im) if self.features is None else self.features