ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_exports.py +2 -2
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +15 -19
- ultralytics/engine/exporter.py +24 -23
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +42 -24
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +65 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
- ultralytics-8.3.91.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
```diff
@@ -176,7 +176,7 @@ class SAM2Model(torch.nn.Module):
         compile_image_encoder: bool = False,
     ):
         """
-
+        Initialize the SAM2Model for video object segmentation with memory-based tracking.

         Args:
             image_encoder (nn.Module): Visual encoder for extracting image features.
@@ -213,9 +213,9 @@ class SAM2Model(torch.nn.Module):
                 the encoder.
             proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
                 encoding in object pointers.
-            use_signed_tpos_enc_to_obj_ptrs (bool):
-                in the temporal positional encoding in the object pointers, only relevant when both
-                and `add_tpos_enc_to_obj_ptrs=True`.
+            use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute distance)
+                in the temporal positional encoding in the object pointers, only relevant when both
+                `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`.
             only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
                 during evaluation.
             pred_obj_scores (bool): Whether to predict if there is an object in the frame.
```
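For readers skimming the `use_signed_tpos_enc_to_obj_ptrs` change, here is a minimal sketch of the signed-versus-unsigned temporal distance the flag toggles. This is an illustration only, not the library code; in the model the resulting offset is fed into a sinusoidal positional encoding.

```python
import torch

def obj_ptr_temporal_offset(frame_idx: int, ptr_frame_idx: int, signed: bool) -> torch.Tensor:
    """Toy illustration: temporal offset between the current frame and an object pointer's frame."""
    diff = frame_idx - ptr_frame_idx  # signed: negative when the pointer frame lies in the future
    return torch.tensor(float(diff if signed else abs(diff)))

print(obj_ptr_temporal_offset(10, 7, signed=True))   # tensor(3.)
print(obj_ptr_temporal_offset(7, 10, signed=True))   # tensor(-3.)
print(obj_ptr_temporal_offset(7, 10, signed=False))  # tensor(3.)
```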
```diff
@@ -332,18 +332,18 @@ class SAM2Model(torch.nn.Module):

     @property
     def device(self):
-        """
+        """Return the device on which the model's parameters are stored."""
         return next(self.parameters()).device

     def forward(self, *args, **kwargs):
-        """
+        """Process image and prompt inputs to generate object masks and scores in video sequences."""
         raise NotImplementedError(
             "Please use the corresponding methods in SAM2VideoPredictor for inference."
             "See notebooks/video_predictor_example.ipynb for an example."
         )

     def _build_sam_heads(self):
-        """
+        """Build SAM-style prompt encoder and mask decoder for image segmentation tasks."""
         self.sam_prompt_embed_dim = self.hidden_dim
         self.sam_image_embedding_size = self.image_size // self.backbone_stride

@@ -545,7 +545,7 @@ class SAM2Model(torch.nn.Module):
         )

     def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
-        """
+        """Process mask inputs directly as output, bypassing SAM encoder/decoder."""
         # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
         out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
         mask_inputs_float = mask_inputs.float()
```
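As a quick sanity check of the -10/+10 logit trick visible in the context lines above, the following standalone snippet (illustration only) shows how a binary mask maps to near-0/1 probabilities after the sigmoid:

```python
import torch

mask = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # binary mask input
out_scale, out_bias = 20.0, -10.0
logits = mask * out_scale + out_bias           # 0 -> -10, 1 -> +10
print(torch.sigmoid(logits))
# tensor([[4.5398e-05, 9.9995e-01],
#         [9.9995e-01, 4.5398e-05]])
```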
```diff
@@ -592,7 +592,7 @@ class SAM2Model(torch.nn.Module):
         )

     def forward_image(self, img_batch: torch.Tensor):
-        """
+        """Process image batch through encoder to extract multi-level features for SAM model."""
         backbone_out = self.image_encoder(img_batch)
         if self.use_high_res_features_in_sam:
             # precompute projected level 0 and level 1 features in SAM decoder
@@ -602,7 +602,7 @@ class SAM2Model(torch.nn.Module):
         return backbone_out

     def _prepare_backbone_features(self, backbone_out):
-        """
+        """Prepare and flatten visual features from the image backbone output for further processing."""
         assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
         assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

@@ -627,7 +627,7 @@ class SAM2Model(torch.nn.Module):
         num_frames,
         track_in_reverse=False,  # tracking in reverse time order (for demo usage)
     ):
-        """
+        """Prepare memory-conditioned features by fusing current frame's visual features with previous memories."""
         B = current_vision_feats[-1].size(1)  # batch size on this frame
         C = self.hidden_dim
         H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
@@ -788,7 +788,7 @@ class SAM2Model(torch.nn.Module):
         object_score_logits,
         is_mask_from_pts,
     ):
-        """
+        """Encode frame features and masks into a new memory representation for video segmentation."""
         B = current_vision_feats[-1].size(1)  # batch size on this frame
         C = self.hidden_dim
         H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
@@ -838,7 +838,7 @@ class SAM2Model(torch.nn.Module):
         track_in_reverse,
         prev_sam_mask_logits,
     ):
-        """
+        """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
         current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
         # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
         if len(current_vision_feats) > 1:
@@ -893,9 +893,7 @@ class SAM2Model(torch.nn.Module):
         object_score_logits,
         current_out,
     ):
-        """
-        used in future frames).
-        """
+        """Run memory encoder on predicted mask to encode it into a new memory feature for future frames."""
         if run_mem_encoder and self.num_maskmem > 0:
             high_res_masks_for_mem_enc = high_res_masks
             maskmem_features, maskmem_pos_enc = self._encode_new_memory(
@@ -932,7 +930,7 @@ class SAM2Model(torch.nn.Module):
         # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
         prev_sam_mask_logits=None,
     ):
-        """
+        """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
         current_out, sam_outputs, _, _ = self._track_step(
             frame_idx,
             is_init_cond_frame,
@@ -970,7 +968,7 @@ class SAM2Model(torch.nn.Module):
         return current_out

     def _use_multimask(self, is_init_cond_frame, point_inputs):
-        """
+        """Determine whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
         num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
         return (
             self.multimask_output_in_sam
@@ -980,7 +978,7 @@ class SAM2Model(torch.nn.Module):

     @staticmethod
     def _apply_non_overlapping_constraints(pred_masks):
-        """
+        """Apply non-overlapping constraints to masks, keeping the highest scoring object per location."""
         batch_size = pred_masks.size(0)
         if batch_size == 1:
             return pred_masks
```
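A minimal sketch of the idea behind `_apply_non_overlapping_constraints`, assuming `pred_masks` holds per-object mask logits of shape (num_objects, 1, H, W). The suppression value and exact tensor ops here are illustrative, not the library implementation:

```python
import torch

def non_overlapping(pred_masks: torch.Tensor, suppress_to: float = -10.0) -> torch.Tensor:
    """Keep, at each pixel, only the object with the highest mask logit; clamp the others down."""
    winner = torch.argmax(pred_masks, dim=0, keepdim=True)                # (1, 1, H, W) winning object index
    obj_ids = torch.arange(pred_masks.size(0), device=pred_masks.device)  # (num_objects,)
    keep = winner == obj_ids[:, None, None, None]                         # (num_objects, 1, H, W) boolean
    return torch.where(keep, pred_masks, pred_masks.clamp(max=suppress_to))

masks = torch.randn(3, 1, 4, 4)  # 3 tracked objects
print(non_overlapping(masks).shape)  # torch.Size([3, 1, 4, 4])
```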
```diff
@@ -1001,12 +999,7 @@ class SAM2Model(torch.nn.Module):
         self.binarize_mask_from_pts_for_mem_enc = binarize

     def set_imgsz(self, imgsz):
-        """
-        Set image size to make model compatible with different image sizes.
-
-        Args:
-            imgsz (Tuple[int, int]): The size of the input image.
-        """
+        """Set image size to make model compatible with different image sizes."""
         self.image_size = imgsz[0]
         self.sam_prompt_encoder.input_image_size = imgsz
         self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
```
```diff
@@ -27,7 +27,7 @@ class Conv2d_BN(torch.nn.Sequential):

     Attributes:
         c (torch.nn.Conv2d): 2D convolution layer.
-
+        bn (torch.nn.BatchNorm2d): Batch normalization layer.

     Methods:
         __init__: Initializes the Conv2d_BN with specified parameters.
@@ -265,9 +265,9 @@ class ConvLayer(nn.Module):
         dim (int): The dimensionality of the input and output.
         input_resolution (Tuple[int, int]): The resolution of the input image.
         depth (int): The number of MBConv layers in the block.
-        activation (
+        activation (nn.Module): Activation function applied after each convolution.
         drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
-        downsample (Optional[
+        downsample (Optional[nn.Module]): Function for downsampling the output. None to skip downsampling.
         use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
         out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
         conv_expand_ratio (float): Expansion ratio for the MBConv layers.
@@ -413,12 +413,9 @@ class Attention(torch.nn.Module):
     Args:
         dim (int): The dimensionality of the input and output.
         key_dim (int): The dimensionality of the keys and queries.
-        num_heads (int): Number of attention heads.
-        attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors.
-        resolution (Tuple[int, int]): Spatial resolution of the input feature map.
-
-    Raises:
-        AssertionError: If 'resolution' is not a tuple of length 2.
+        num_heads (int): Number of attention heads.
+        attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors.
+        resolution (Tuple[int, int]): Spatial resolution of the input feature map.

     Examples:
         >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
@@ -821,22 +818,20 @@ class TinyViT(nn.Module):
     attention and convolution blocks, and a classification head.

     Args:
-        img_size (int): Size of the input image.
-        in_chans (int): Number of input channels.
-        num_classes (int): Number of classes for classification.
+        img_size (int): Size of the input image.
+        in_chans (int): Number of input channels.
+        num_classes (int): Number of classes for classification.
         embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
-
-        depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
+        depths (Tuple[int, int, int, int]): Number of blocks in each stage.
         num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
-
-
-
-
-
-
-
-
-        layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
+        window_sizes (Tuple[int, int, int, int]): Window sizes for each stage.
+        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+        drop_rate (float): Dropout rate.
+        drop_path_rate (float): Stochastic depth rate.
+        use_checkpoint (bool): Whether to use checkpointing to save memory.
+        mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
+        local_conv_size (int): Kernel size for local convolutions.
+        layer_lr_decay (float): Layer-wise learning rate decay factor.

     Examples:
         >>> model = TinyViT(img_size=224, num_classes=1000)
@@ -992,12 +987,7 @@ class TinyViT(nn.Module):
         return self.forward_features(x)

     def set_imgsz(self, imgsz=[1024, 1024]):
-        """
-        Set image size to make model compatible with different image sizes.
-
-        Args:
-            imgsz (Tuple[int, int]): The size of the input image.
-        """
+        """Set image size to make model compatible with different image sizes."""
         imgsz = [s // 4 for s in imgsz]
         self.patches_resolution = imgsz
         for i, layer in enumerate(self.layers):
```
```diff
@@ -57,23 +57,6 @@ class TwoWayTransformer(nn.Module):
             mlp_dim (int): Internal channel dimension for the MLP block.
             activation (Type[nn.Module]): Activation function to use in the MLP block.
             attention_downsample_rate (int): Downsampling rate for attention mechanism.
-
-        Attributes:
-            depth (int): Number of layers in the transformer.
-            embedding_dim (int): Channel dimension for input embeddings.
-            num_heads (int): Number of heads for multihead attention.
-            mlp_dim (int): Internal channel dimension for the MLP block.
-            layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
-            final_attn_token_to_image (Attention): Final attention layer from queries to image.
-            norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
-
-        Examples:
-            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
-            >>> image_embedding = torch.randn(1, 256, 32, 32)
-            >>> image_pe = torch.randn(1, 256, 32, 32)
-            >>> point_embedding = torch.randn(1, 100, 256)
-            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
-            >>> print(output_queries.shape, output_image.shape)
         """
         super().__init__()
         self.depth = depth
@@ -104,23 +87,16 @@ class TwoWayTransformer(nn.Module):
         point_embedding: Tensor,
     ) -> Tuple[Tensor, Tensor]:
         """
-
+        Process image and point embeddings through the Two-Way Transformer.

         Args:
-            image_embedding (
-            image_pe (
-            point_embedding (
+            image_embedding (Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+            image_pe (Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+            point_embedding (Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).

         Returns:
-            (
-
-        Examples:
-            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
-            >>> image_embedding = torch.randn(1, 256, 32, 32)
-            >>> image_pe = torch.randn(1, 256, 32, 32)
-            >>> point_embedding = torch.randn(1, 100, 256)
-            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
-            >>> print(output_queries.shape, output_image.shape)
+            queries (Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).
+            keys (Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).
         """
         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
```
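The doctest removed from the class docstring above is still the quickest way to see the expected shapes in practice. Re-created here as a hedged sketch (the import path is assumed from the ultralytics SAM module layout; the printed shapes follow the Args/Returns documented in the new docstring):

```python
import torch
from ultralytics.models.sam.modules.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 32, 32)  # (B, embedding_dim, H, W)
image_pe = torch.randn(1, 256, 32, 32)         # same shape as image_embedding
point_embedding = torch.randn(1, 100, 256)     # (B, N_points, embedding_dim)
queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)
# torch.Size([1, 100, 256]) torch.Size([1, 1024, 256])
```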
```diff
@@ -191,7 +167,7 @@ class TwoWayAttentionBlock(nn.Module):
         skip_first_layer_pe: bool = False,
     ) -> None:
         """
-
+        Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.

         This block implements a specialized transformer layer with four main components: self-attention on sparse
         inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
@@ -204,15 +180,6 @@ class TwoWayAttentionBlock(nn.Module):
             activation (Type[nn.Module]): Activation function for the MLP block.
             attention_downsample_rate (int): Downsampling rate for the attention mechanism.
             skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
-
-        Examples:
-            >>> embedding_dim, num_heads = 256, 8
-            >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
-            >>> queries = torch.randn(1, 100, embedding_dim)
-            >>> keys = torch.randn(1, 1000, embedding_dim)
-            >>> query_pe = torch.randn(1, 100, embedding_dim)
-            >>> key_pe = torch.randn(1, 1000, embedding_dim)
-            >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
         """
         super().__init__()
         self.self_attn = Attention(embedding_dim, num_heads)
@@ -230,7 +197,19 @@ class TwoWayAttentionBlock(nn.Module):
         self.skip_first_layer_pe = skip_first_layer_pe

     def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
-        """
+        """
+        Apply two-way attention to process query and key embeddings in a transformer block.
+
+        Args:
+            queries (Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
+            keys (Tensor): Key embeddings with shape (B, N_keys, embedding_dim).
+            query_pe (Tensor): Positional encodings for queries with same shape as queries.
+            key_pe (Tensor): Positional encodings for keys with same shape as keys.
+
+        Returns:
+            queries (Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).
+            keys (Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).
+        """
         # Self attention block
         if self.skip_first_layer_pe:
             queries = self.self_attn(q=queries, k=queries, v=queries)
@@ -301,27 +280,16 @@ class Attention(nn.Module):
         kv_in_dim: int = None,
     ) -> None:
         """
-
-
-        This class implements a multi-head attention mechanism with optional downsampling of the internal
-        dimension for queries, keys, and values.
+        Initialize the Attention module with specified dimensions and settings.

         Args:
             embedding_dim (int): Dimensionality of input embeddings.
             num_heads (int): Number of attention heads.
-            downsample_rate (int): Factor by which internal dimensions are downsampled.
+            downsample_rate (int): Factor by which internal dimensions are downsampled.
             kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.

         Raises:
             AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
-
-        Examples:
-            >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
-            >>> q = torch.randn(1, 100, 256)
-            >>> k = v = torch.randn(1, 50, 256)
-            >>> output = attn(q, k, v)
-            >>> print(output.shape)
-            torch.Size([1, 100, 256])
         """
         super().__init__()
         self.embedding_dim = embedding_dim
@@ -337,20 +305,30 @@ class Attention(nn.Module):

     @staticmethod
     def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
-        """
+        """Separate the input tensor into the specified number of attention heads."""
         b, n, c = x.shape
         x = x.reshape(b, n, num_heads, c // num_heads)
         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

     @staticmethod
     def _recombine_heads(x: Tensor) -> Tensor:
-        """
+        """Recombine separated attention heads into a single tensor."""
         b, n_heads, n_tokens, c_per_head = x.shape
         x = x.transpose(1, 2)
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
-        """
+        """
+        Apply multi-head attention to query, key, and value tensors with optional downsampling.
+
+        Args:
+            q (Tensor): Query tensor with shape (B, N_q, embedding_dim).
+            k (Tensor): Key tensor with shape (B, N_k, embedding_dim).
+            v (Tensor): Value tensor with shape (B, N_k, embedding_dim).
+
+        Returns:
+            (Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).
+        """
         # Input projections
         q = self.q_proj(q)
         k = self.k_proj(k)
```
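The `_separate_heads`/`_recombine_heads` pair shown in the context lines is a standard reshape-transpose round trip; a self-contained illustration of just that pattern:

```python
import torch

b, n, c, num_heads = 1, 100, 256, 8
x = torch.randn(b, n, c)

heads = x.reshape(b, n, num_heads, c // num_heads).transpose(1, 2)          # B x N_heads x N_tokens x C_per_head
merged = heads.transpose(1, 2).reshape(b, n, num_heads * (c // num_heads))  # back to B x N_tokens x C

print(heads.shape, merged.shape, torch.equal(x, merged))
# torch.Size([1, 8, 100, 32]) torch.Size([1, 100, 256]) True
```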
```diff
@@ -8,7 +8,7 @@ import torch.nn.functional as F

 def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
     """
-
+    Select the closest conditioning frames to a given frame index.

     Args:
         frame_idx (int): Current frame index.
@@ -37,17 +37,17 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
         assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
         selected_outputs = {}

-        #
+        # The closest conditioning frame before `frame_idx` (if any)
         idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
         if idx_before is not None:
             selected_outputs[idx_before] = cond_frame_outputs[idx_before]

-        #
+        # The closest conditioning frame after `frame_idx` (if any)
         idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
         if idx_after is not None:
             selected_outputs[idx_after] = cond_frame_outputs[idx_after]

-        #
+        # Add other temporally closest conditioning frames until reaching a total
         # of `max_cond_frame_num` conditioning frames.
         num_remain = max_cond_frame_num - len(selected_outputs)
         inds_remain = sorted(
```
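To make the before/after/nearest-remaining selection above concrete, here is a standalone re-implementation of just the index-selection logic (the real function also carries along and returns the per-frame outputs themselves); the frame numbers are made up:

```python
def pick_cond_frame_indices(frame_idx, cond_frame_indices, max_cond_frame_num):
    """Choose up to max_cond_frame_num conditioning frames closest in time to frame_idx."""
    assert max_cond_frame_num >= 2
    chosen = set()
    idx_before = max((t for t in cond_frame_indices if t < frame_idx), default=None)  # closest earlier frame
    idx_after = min((t for t in cond_frame_indices if t >= frame_idx), default=None)  # closest frame at/after
    chosen.update(t for t in (idx_before, idx_after) if t is not None)
    remaining = sorted((t for t in cond_frame_indices if t not in chosen), key=lambda t: abs(t - frame_idx))
    chosen.update(remaining[: max_cond_frame_num - len(chosen)])
    return sorted(chosen)

print(pick_cond_frame_indices(18, [0, 5, 20, 40], max_cond_frame_num=3))  # [0, 5, 20]
```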
```diff
@@ -61,7 +61,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num


 def get_1d_sine_pe(pos_inds, dim, temperature=10000):
-    """
+    """Generate 1D sinusoidal positional embeddings for given positions and dimensions."""
     pe_dim = dim // 2
     dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
     dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
```
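A hedged usage sketch of `get_1d_sine_pe`: the `pe_dim = dim // 2` context above suggests the sine and cosine halves are concatenated to `dim` channels, and the import path is assumed from the module layout.

```python
import torch
from ultralytics.models.sam.modules.utils import get_1d_sine_pe

pos = torch.arange(4)              # positions 0..3
pe = get_1d_sine_pe(pos, dim=256)  # sinusoidal features per position
print(pe.shape)                    # expected: torch.Size([4, 256])
```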
```diff
@@ -72,7 +72,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):


 def init_t_xy(end_x: int, end_y: int):
-    """
+    """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions."""
     t = torch.arange(end_x * end_y, dtype=torch.float32)
     t_x = (t % end_x).float()
     t_y = torch.div(t, end_x, rounding_mode="floor").float()
@@ -80,7 +80,7 @@ def init_t_xy(end_x: int, end_y: int):


 def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
-    """
+    """Compute axial complex exponential positional encodings for 2D spatial positions in a grid."""
     freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
     freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

```
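A hedged usage sketch of `compute_axial_cis`: the signature is taken from the context lines, and the expected output assumes the function concatenates the x- and y-axis components into one unit-magnitude complex frequency row per grid position.

```python
import torch
from ultralytics.models.sam.modules.utils import compute_axial_cis

freqs_cis = compute_axial_cis(dim=64, end_x=8, end_y=8)  # dim is the per-head channel count
print(freqs_cis.dtype, freqs_cis.shape)                  # expected: torch.complex64 torch.Size([64, 32])
```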
```diff
@@ -93,7 +93,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):


 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    """
+    """Reshape frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
     ndim = x.ndim
     assert 0 <= 1 < ndim
     assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
@@ -107,15 +107,15 @@ def apply_rotary_enc(
     freqs_cis: torch.Tensor,
     repeat_freqs_k: bool = False,
 ):
-    """
+    """Apply rotary positional encoding to query and key tensors using complex-valued frequency components."""
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     if xk_ is None:
-        #
+        # No keys to rotate, due to dropout
         return xq_out.type_as(xq).to(xq.device), xk
-    #
+    # Repeat freqs along seq_len dim to match k seq_len
     if repeat_freqs_k:
         r = xk_.shape[-2] // xq_.shape[-2]
         freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
```
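The complex-number rotary trick used above can be demonstrated in isolation: pack channel pairs as complex numbers, multiply by unit-magnitude phases, and unpack. This is an illustration of the technique, not the library function:

```python
import torch

B, n_heads, seq_len, head_dim = 1, 2, 4, 8
x = torch.randn(B, n_heads, seq_len, head_dim)

# One phase angle per (position, channel pair); unit-magnitude complex numbers encode the rotation.
angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), torch.ones(head_dim // 2))
freqs_cis = torch.polar(torch.ones_like(angles), angles)                       # complex64, |z| = 1

x_c = torch.view_as_complex(x.reshape(B, n_heads, seq_len, head_dim // 2, 2))  # pair channels as complex
x_rot = torch.view_as_real(x_c * freqs_cis).flatten(3)                         # rotate, unpack to real

print(x_rot.shape, torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))
# torch.Size([1, 2, 4, 8]) True  (a pure rotation preserves per-token norms)
```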
```diff
@@ -125,7 +125,7 @@

 def window_partition(x, window_size):
     """
-
+    Partition input tensor into non-overlapping windows with padding if needed.

     Args:
         x (torch.Tensor): Input tensor with shape (B, H, W, C).
@@ -157,7 +157,7 @@

 def window_unpartition(windows, window_size, pad_hw, hw):
     """
-
+    Unpartition windowed sequences into original sequences and remove padding.

     This function reverses the windowing process, reconstructing the original input from windowed segments
     and removing any padding that was added during the windowing process.
```
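A hedged round-trip example for the two functions above, assuming the usual ViTDet-style return values (the windows plus the padded height/width):

```python
import torch
from ultralytics.models.sam.modules.utils import window_partition, window_unpartition

x = torch.randn(1, 14, 14, 96)                        # (B, H, W, C); 14 is not a multiple of the window size
windows, pad_hw = window_partition(x, window_size=8)  # pads to 16x16, then tiles into 8x8 windows
print(windows.shape, pad_hw)                          # expected: torch.Size([4, 8, 8, 96]) (16, 16)

restored = window_unpartition(windows, 8, pad_hw, (14, 14))
print(torch.equal(restored, x))                       # True: the padding is stripped on the way back
```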
```diff
@@ -195,7 +195,7 @@

 def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
     """
-
+    Extract relative positional embeddings based on query and key sizes.

     Args:
         q_size (int): Size of the query.
@@ -244,7 +244,7 @@ def add_decomposed_rel_pos(
     k_size: Tuple[int, int],
 ) -> torch.Tensor:
     """
-
+    Add decomposed Relative Positional Embeddings to the attention map.

     This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
     paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
@@ -701,9 +701,6 @@ class SAM2Predictor(Predictor):
             - The method supports batched inference for multiple objects when points or bboxes are provided.
             - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
            - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
-
-        References:
-            - SAM2 Paper: [Add link to SAM2 paper when available]
         """
         features = self.get_im_features(im) if self.features is None else self.features

```