dgenerate-ultralytics-headless 8.3.196-py3-none-any.whl → 8.3.248-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to the supported public registries, and is provided for informational purposes only.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -13,7 +13,7 @@ from torch.nn.init import trunc_normal_
 from ultralytics.nn.modules import MLP
 from ultralytics.utils import LOGGER
 
-from .blocks import SAM2TwoWayTransformer
+from .blocks import SAM2TwoWayTransformer, TwoWayTransformer
 from .decoders import MaskDecoder, SAM2MaskDecoder
 from .encoders import ImageEncoderViT, PromptEncoder
 from .utils import get_1d_sine_pe, select_closest_cond_frames
@@ -23,11 +23,10 @@ NO_OBJ_SCORE = -1024.0
 
 
 class SAMModel(nn.Module):
-    """
-    Segment Anything Model (SAM) for object segmentation tasks.
+    """Segment Anything Model (SAM) for object segmentation tasks.
 
-    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
-    and input prompts.
+    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input
+    prompts.
 
     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
@@ -61,8 +60,7 @@ class SAMModel(nn.Module):
         pixel_mean: list[float] = (123.675, 116.28, 103.53),
         pixel_std: list[float] = (58.395, 57.12, 57.375),
     ) -> None:
-        """
-        Initialize the SAMModel class to predict object masks from an image and input prompts.
+        """Initialize the SAMModel class to predict object masks from an image and input prompts.
 
         Args:
             image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
@@ -71,13 +69,6 @@ class SAMModel(nn.Module):
             pixel_mean (list[float]): Mean values for normalizing pixels in the input image.
             pixel_std (list[float]): Standard deviation values for normalizing pixels in the input image.
 
-        Examples:
-            >>> image_encoder = ImageEncoderViT(...)
-            >>> prompt_encoder = PromptEncoder(...)
-            >>> mask_decoder = MaskDecoder(...)
-            >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
-            >>> # Further usage depends on SAMPredictor class
-
         Notes:
             All forward() operations moved to SAMPredictor.
         """
@@ -98,11 +89,10 @@ class SAMModel(nn.Module):
 
 
 class SAM2Model(torch.nn.Module):
-    """
-    SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
+    """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
 
-    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
-    for temporal consistency and efficient tracking of objects across frames.
+    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal
+    consistency and efficient tracking of objects across frames.
 
     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
@@ -136,24 +126,24 @@ class SAM2Model(torch.nn.Module):
         use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
         no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.
         max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
-        directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
-            first frame.
-        multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
-            conditioning frames.
+        directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+            frame.
+        multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+            frames.
         multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
         multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
         multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
         use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
         iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
         memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
-        non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
-            memory encoder during evaluation.
+        non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+            encoder during evaluation.
         sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
         sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
-        binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
-            with clicks during evaluation.
-        use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
-            prompt encoder and mask decoder on frames with mask input.
+        binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+            clicks during evaluation.
+        use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt
+            encoder and mask decoder on frames with mask input.
 
     Methods:
         forward_image: Process image batch through encoder to extract multi-level features.
@@ -208,8 +198,7 @@ class SAM2Model(torch.nn.Module):
         sam_mask_decoder_extra_args=None,
         compile_image_encoder: bool = False,
     ):
-        """
-        Initialize the SAM2Model for video object segmentation with memory-based tracking.
+        """Initialize the SAM2Model for video object segmentation with memory-based tracking.
 
         Args:
             image_encoder (nn.Module): Visual encoder for extracting image features.
@@ -220,35 +209,35 @@ class SAM2Model(torch.nn.Module):
             backbone_stride (int): Stride of the image backbone output.
             sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
             sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
-            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
-                with clicks during evaluation.
+            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+                clicks during evaluation.
             use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
                 prompt encoder and mask decoder on frames with mask input.
             max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
-            directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
-                first frame.
+            directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+                frame.
             use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
-            multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
-                conditioning frames.
+            multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+                frames.
             multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
             multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
             multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
            use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
             iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
             memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
-            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
-                memory encoder during evaluation.
+            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+                encoder during evaluation.
             use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
             max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
                 cross-attention.
-            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
-                the encoder.
+            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the
+                encoder.
             proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
                 encoding in object pointers.
             use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
                 in the object pointers.
-            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
-                during evaluation.
+            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
+                evaluation.
             pred_obj_scores (bool): Whether to predict if there is an object in the frame.
             pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
             fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
@@ -257,15 +246,6 @@ class SAM2Model(torch.nn.Module):
             no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
             sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder.
             compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
-
-        Examples:
-            >>> image_encoder = ImageEncoderViT(...)
-            >>> memory_attention = SAM2TwoWayTransformer(...)
-            >>> memory_encoder = nn.Sequential(...)
-            >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
-            >>> image_batch = torch.rand(1, 3, 512, 512)
-            >>> features = model.forward_image(image_batch)
-            >>> track_results = model.track_step(0, True, features, None, None, None, {})
         """
         super().__init__()
 
@@ -349,6 +329,7 @@ class SAM2Model(torch.nn.Module):
 
         self._build_sam_heads()
         self.max_cond_frames_in_attn = max_cond_frames_in_attn
+        self.add_all_frames_to_correct_as_cond = True
 
         # Model compilation
         if compile_image_encoder:
@@ -428,25 +409,23 @@ class SAM2Model(torch.nn.Module):
         high_res_features=None,
         multimask_output=False,
     ):
-        """
-        Forward pass through SAM prompt encoders and mask heads.
+        """Forward pass through SAM prompt encoders and mask heads.
 
         This method processes image features and optional point/mask inputs to generate object masks and scores.
 
         Args:
             backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
             point_inputs (dict[str, torch.Tensor] | None): Dictionary containing point prompts.
-                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing the absolute
-                    pixel-unit coordinates in (x, y) format for P input points.
-                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
-                    0 means negative clicks, and -1 means padding.
-            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
-                same spatial size as the image.
-            high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes
-                (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
-                for SAM decoder.
-            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
-                output only 1 mask and its IoU estimate.
+                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in
+                    (x, y) format for P input points.
+                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative
+                    clicks, and -1 means padding.
+            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial
+                size as the image.
+            high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B,
+                C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder.
+            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1
+                mask and its IoU estimate.
 
         Returns:
             low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
@@ -472,7 +451,7 @@ class SAM2Model(torch.nn.Module):
             ...     object_score_logits,
             ... ) = results
         """
-        B = backbone_features.size(0)
+        B = backbone_features.shape[0]
         device = backbone_features.device
         assert backbone_features.size(1) == self.sam_prompt_embed_dim
         assert backbone_features.size(2) == self.sam_image_embedding_size
@@ -482,7 +461,7 @@ class SAM2Model(torch.nn.Module):
         if point_inputs is not None:
             sam_point_coords = point_inputs["point_coords"]
             sam_point_labels = point_inputs["point_labels"]
-            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+            assert sam_point_coords.shape[0] == B and sam_point_labels.shape[0] == B
         else:
             # If no points are provide, pad with an empty point (with label -1)
             sam_point_coords = torch.zeros(B, 1, 2, device=device, dtype=backbone_features.dtype)
@@ -495,7 +474,7 @@ class SAM2Model(torch.nn.Module):
             assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
             if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                 sam_mask_prompt = F.interpolate(
-                    mask_inputs.float(),
+                    mask_inputs.to(backbone_features.dtype),
                     size=self.sam_prompt_encoder.mask_input_size,
                     align_corners=False,
                     mode="bilinear",
@@ -585,15 +564,15 @@ class SAM2Model(torch.nn.Module):
             antialias=True,  # use antialias for downsampling
         )
         # a dummy IoU prediction of all 1's under mask input
-        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+        ious = mask_inputs.new_ones(mask_inputs.shape[0], 1).float()
         if not self.use_obj_ptrs_in_encoder or backbone_features is None or high_res_features is None:
             # all zeros as a dummy object pointer (of shape [B, C])
-            obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
+            obj_ptr = torch.zeros(mask_inputs.shape[0], self.hidden_dim, device=mask_inputs.device)
         else:
             # produce an object pointer using the SAM decoder from the mask input
             _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
                 backbone_features=backbone_features,
-                mask_inputs=self.mask_downsample(mask_inputs_float),
+                mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)),
                 high_res_features=high_res_features,
             )
         # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
@@ -628,8 +607,14 @@ class SAM2Model(torch.nn.Module):
             backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
         return backbone_out
 
-    def _prepare_backbone_features(self, backbone_out):
+    def _prepare_backbone_features(self, backbone_out, batch=1):
         """Prepare and flatten visual features from the image backbone output for further processing."""
+        if batch > 1:  # expand features if there's more than one prompt
+            backbone_out = {
+                **backbone_out,
+                "backbone_fpn": [feat.expand(batch, -1, -1, -1) for feat in backbone_out["backbone_fpn"]],
+                "vision_pos_enc": [pos.expand(batch, -1, -1, -1) for pos in backbone_out["vision_pos_enc"]],
+            }
         assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
         assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
 
@@ -640,7 +625,6 @@ class SAM2Model(torch.nn.Module):
         # flatten NxCxHxW to HWxNxC
         vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
         vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
-
         return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
 
     def _prepare_memory_conditioned_features(
@@ -712,7 +696,7 @@ class SAM2Model(torch.nn.Module):
                     continue  # skip padding frames
                 # "maskmem_features" might have been offloaded to CPU in demo use cases,
                 # so we load it back to inference device (it's a no-op if it's already on device).
-                feats = prev["maskmem_features"].to(device=device, non_blocking=True)
+                feats = prev["maskmem_features"].to(device=device, non_blocking=device.type == "cuda")
                 to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                 # Spatial positional encoding (it might have been offloaded to CPU in eval)
                 maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
@@ -803,7 +787,7 @@ class SAM2Model(torch.nn.Module):
             memory_pos=memory_pos_embed,
             num_obj_ptr_tokens=num_obj_ptr_tokens,
         )
-        # reshape the output (HW)BC => BCHW
+        # Reshape output (HW)BC => BCHW
         pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
         return pix_feat_with_mem
 
@@ -840,7 +824,6 @@ class SAM2Model(torch.nn.Module):
             mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
         maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
         maskmem_features = maskmem_out["vision_features"]
-        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
         # add a no-object embedding to the spatial memory to indicate that the frame
         # is predicted to be occluded (i.e. no object is appearing in the frame)
         if self.no_obj_embed_spatial is not None:
@@ -849,7 +832,7 @@ class SAM2Model(torch.nn.Module):
                 ..., None, None
             ].expand(*maskmem_features.shape)
 
-        return maskmem_features, maskmem_pos_enc
+        return maskmem_features, maskmem_out["vision_pos_enc"]
 
     def _track_step(
         self,
@@ -881,7 +864,7 @@ class SAM2Model(torch.nn.Module):
             pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
             sam_outputs = self._use_mask_as_output(mask_inputs, pix_feat, high_res_features)
         else:
-            # fused the visual feature with previous memory features in the memory bank
+            # Fuse visual features with previous memory features in the memory bank
            pix_feat = self._prepare_memory_conditioned_features(
                 frame_idx=frame_idx,
                 is_init_cond_frame=is_init_cond_frame,
@@ -1006,7 +989,7 @@ class SAM2Model(torch.nn.Module):
     @staticmethod
     def _apply_non_overlapping_constraints(pred_masks):
        """Apply non-overlapping constraints to masks, keeping the highest scoring object per location."""
-        batch_size = pred_masks.size(0)
+        batch_size = pred_masks.shape[0]
         if batch_size == 1:
             return pred_masks
 
@@ -1027,7 +1010,151 @@ class SAM2Model(torch.nn.Module):
 
     def set_imgsz(self, imgsz):
         """Set image size to make model compatible with different image sizes."""
+        if hasattr(self.image_encoder, "set_imgsz"):
+            self.image_encoder.set_imgsz(imgsz)
         self.image_size = imgsz[0]
         self.sam_prompt_encoder.input_image_size = imgsz
-        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.image_embedding_size = [
+            x // self.backbone_stride for x in imgsz
+        ]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.mask_input_size = [
+            x // self.backbone_stride * 4 for x in imgsz
+        ]  # fixed ViT patch size of 16
         self.sam_image_embedding_size = self.image_size // self.backbone_stride  # update image embedding size
+
+
+class SAM3Model(SAM2Model):
+    """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+
+    def __init__(
+        self,
+        image_encoder,
+        memory_attention,
+        memory_encoder,
+        num_maskmem=7,
+        image_size=1008,
+        backbone_stride=14,
+        sigmoid_scale_for_mem_enc=1,
+        sigmoid_bias_for_mem_enc=0,
+        binarize_mask_from_pts_for_mem_enc=False,
+        use_mask_input_as_output_without_sam=False,
+        max_cond_frames_in_attn=-1,
+        directly_add_no_mem_embed=False,
+        use_high_res_features_in_sam=False,
+        multimask_output_in_sam=False,
+        multimask_min_pt_num=1,
+        multimask_max_pt_num=1,
+        multimask_output_for_tracking=False,
+        use_multimask_token_for_obj_ptr: bool = False,
+        iou_prediction_use_sigmoid=False,
+        memory_temporal_stride_for_eval=1,
+        non_overlap_masks_for_mem_enc=False,
+        use_obj_ptrs_in_encoder=False,
+        max_obj_ptrs_in_encoder=16,
+        add_tpos_enc_to_obj_ptrs=True,
+        proj_tpos_enc_in_obj_ptrs=False,
+        use_signed_tpos_enc_to_obj_ptrs=False,
+        only_obj_ptrs_in_the_past_for_eval=False,
+        pred_obj_scores: bool = False,
+        pred_obj_scores_mlp: bool = False,
+        fixed_no_obj_ptr: bool = False,
+        soft_no_obj_ptr: bool = False,
+        use_mlp_for_obj_ptr_proj: bool = False,
+        no_obj_embed_spatial: bool = False,
+        sam_mask_decoder_extra_args=None,
+        compile_image_encoder: bool = False,
+    ):
+        """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+        super().__init__(
+            image_encoder,
+            memory_attention,
+            memory_encoder,
+            num_maskmem,
+            image_size,
+            backbone_stride,
+            sigmoid_scale_for_mem_enc,
+            sigmoid_bias_for_mem_enc,
+            binarize_mask_from_pts_for_mem_enc,
+            use_mask_input_as_output_without_sam,
+            max_cond_frames_in_attn,
+            directly_add_no_mem_embed,
+            use_high_res_features_in_sam,
+            multimask_output_in_sam,
+            multimask_min_pt_num,
+            multimask_max_pt_num,
+            multimask_output_for_tracking,
+            use_multimask_token_for_obj_ptr,
+            iou_prediction_use_sigmoid,
+            memory_temporal_stride_for_eval,
+            non_overlap_masks_for_mem_enc,
+            use_obj_ptrs_in_encoder,
+            max_obj_ptrs_in_encoder,
+            add_tpos_enc_to_obj_ptrs,
+            proj_tpos_enc_in_obj_ptrs,
+            use_signed_tpos_enc_to_obj_ptrs,
+            only_obj_ptrs_in_the_past_for_eval,
+            pred_obj_scores,
+            pred_obj_scores_mlp,
+            fixed_no_obj_ptr,
+            soft_no_obj_ptr,
+            use_mlp_for_obj_ptr_proj,
+            no_obj_embed_spatial,
+            sam_mask_decoder_extra_args,
+            compile_image_encoder,
+        )
+        self.sam_mask_decoder = SAM2MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=self.sam_prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=self.sam_prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            use_high_res_features=self.use_high_res_features_in_sam,
+            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+            pred_obj_scores=self.pred_obj_scores,
+            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+            **(self.sam_mask_decoder_extra_args or {}),
+        )
+
+    def forward_image(self, img_batch: torch.Tensor):
+        """Process image batch through encoder to extract multi-level features for SAM model."""
+        backbone_out = self.image_encoder.forward_image_sam2(img_batch)
+        if self.use_high_res_features_in_sam:
+            # precompute projected level 0 and level 1 features in SAM decoder
+            # to avoid running it again on every SAM click
+            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+        return backbone_out
+
+    def set_imgsz(self, imgsz: tuple[int, int]):
+        """Set the image size for the model and mask downsampler."""
+        super().set_imgsz(imgsz)
+        self.memory_encoder.mask_downsampler.interpol_size = [size // 14 * 16 for size in imgsz]
+
+    @staticmethod
+    def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
+        """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints."""
+        area_before = (pred_masks > 0).sum(dim=(-1, -2))
+        area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
+        area_before = torch.clamp(area_before, min=1.0)
+        area_ratio = area_after / area_before
+        keep = area_ratio >= shrink_threshold
+        keep_mask = keep[..., None, None].expand_as(pred_masks)
+        pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))
+        return pred_masks_after
+
+    def _suppress_object_pw_area_shrinkage(self, pred_masks):
+        """This function suppresses masks that shrink in area after applying pixelwise non-overlapping constraints. Note
+        that the final output can still be overlapping.
+        """
+        # Apply pixel-wise non-overlapping constraint based on mask scores
+        pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
+        # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping constraints
+        # NOTE: The output of this function can be a no op if none of the masks shrink by a large factor.
+        pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
+        return pred_masks
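The most visible addition in this release is SAM 3 support: the new `ultralytics/models/sam/sam3/` package, `build_sam3.py`, and the `SAM3Model` class shown in the diff above. Below is a minimal usage sketch, assuming the new model loads through the existing `SAM` wrapper; the `sam3.pt` checkpoint name is an assumption and is not confirmed by this diff.

```python
# Hypothetical sketch of exercising the new SAM 3 support in 8.3.248.
# The "sam3.pt" weight name is an assumption; the diff only shows that
# SAM3Model and build_sam3.py were added, not how checkpoints are named.
from ultralytics import SAM

model = SAM("sam3.pt")        # load a SAM-family checkpoint through the high-level SAM wrapper
results = model("image.jpg")  # run segmentation on a single image
for r in results:
    print(r.masks)            # predicted masks for the image (None if nothing was segmented)
```

The same wrapper already handles SAM and SAM 2 checkpoints (for example `sam2.1_b.pt`), so the sketch can also be run against those weights on either version of the package.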