dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -8,8 +8,12 @@ using SAM. It forms an integral part of the Ultralytics framework and is designe
 segmentation tasks.
 """
 
+from __future__ import annotations
+
 from collections import OrderedDict
+from typing import Any
 
+import cv2
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -34,12 +38,11 @@ from .amg import (
 
 
 class Predictor(BasePredictor):
-"""
-Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
+"""Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
 
-This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
-
-
+This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image segmentation
+tasks. It supports various input prompts like points, bounding boxes, and masks for fine-grained control over
+segmentation results.
 
 Attributes:
 args (SimpleNamespace): Configuration arguments for the predictor.
@@ -47,26 +50,26 @@ class Predictor(BasePredictor):
 device (torch.device): The device (CPU or GPU) on which the model is loaded.
 im (torch.Tensor): The preprocessed input image.
 features (torch.Tensor): Extracted image features.
-prompts (dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
+prompts (dict[str, Any]): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
 segment_all (bool): Flag to indicate if full image segmentation should be performed.
 mean (torch.Tensor): Mean values for image normalization.
 std (torch.Tensor): Standard deviation values for image normalization.
 
 Methods:
-preprocess:
-pre_transform:
-inference:
+preprocess: Prepare input images for model inference.
+pre_transform: Perform initial transformations on the input image.
+inference: Perform segmentation inference based on input prompts.
 prompt_inference: Internal function for prompt-based segmentation inference.
-generate:
-setup_model:
-get_model:
-postprocess: Post-
-setup_source:
-set_image:
-get_im_features:
-set_prompts:
-reset_image:
-remove_small_regions:
+generate: Generate segmentation masks for an entire image.
+setup_model: Initialize the SAM model for inference.
+get_model: Build and return a SAM model.
+postprocess: Post-process model outputs to generate final results.
+setup_source: Set up the data source for inference.
+set_image: Set and preprocess a single image for inference.
+get_im_features: Extract image features using the SAM image encoder.
+set_prompts: Set prompts for subsequent inference.
+reset_image: Reset the current image and its features.
+remove_small_regions: Remove small disconnected regions and holes from masks.
 
 Examples:
 >>> predictor = Predictor()
@@ -77,17 +80,16 @@ class Predictor(BasePredictor):
 """
 
 def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-"""
-Initialize the Predictor with configuration, overrides, and callbacks.
+"""Initialize the Predictor with configuration, overrides, and callbacks.
 
 Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
-callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
-
+callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True for
+optimal results.
 
 Args:
 cfg (dict): Configuration dictionary containing default settings.
-overrides (
-_callbacks (
+overrides (dict | None): Dictionary of values to override default configuration.
+_callbacks (dict | None): Dictionary of callback functions to customize behavior.
 
 Examples:
 >>> predictor_example = Predictor(cfg=DEFAULT_CFG)
@@ -105,17 +107,16 @@ class Predictor(BasePredictor):
 self.segment_all = False
 
 def preprocess(self, im):
-"""
-Preprocess the input image for model inference.
+"""Preprocess the input image for model inference.
 
 This method prepares the input image by applying transformations and normalization. It supports both
 torch.Tensor and list of np.ndarray as input formats.
 
 Args:
-im (torch.Tensor |
+im (torch.Tensor | list[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
 
 Returns:
-
+(torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
 
 Examples:
 >>> predictor = Predictor()
@@ -132,23 +133,22 @@ class Predictor(BasePredictor):
 im = torch.from_numpy(im)
 
 im = im.to(self.device)
-im = im.half() if self.model.fp16 else im.float()
 if not_tensor:
 im = (im - self.mean) / self.std
+im = im.half() if self.model.fp16 else im.float()
 return im
 
 def pre_transform(self, im):
-"""
-Perform initial transformations on the input image for preprocessing.
+"""Perform initial transformations on the input image for preprocessing.
 
-This method applies transformations such as resizing to prepare the image for further preprocessing.
-
+This method applies transformations such as resizing to prepare the image for further preprocessing. Currently,
+batched inference is not supported; hence the list length should be 1.
 
 Args:
-im (
+im (list[np.ndarray]): List containing a single image in HWC numpy array format.
 
 Returns:
-(
+(list[np.ndarray]): List containing the transformed image.
 
 Raises:
 AssertionError: If the input list contains more than one image.
@@ -165,26 +165,25 @@ class Predictor(BasePredictor):
 return [letterbox(image=x) for x in im]
 
 def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
-"""
-Perform image segmentation inference based on the given input cues, using the currently loaded image.
+"""Perform image segmentation inference based on the given input cues, using the currently loaded image.
 
-This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
-
+This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder,
+and mask decoder for real-time and promptable segmentation tasks.
 
 Args:
 im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
-bboxes (np.ndarray |
-points (np.ndarray |
-labels (np.ndarray |
+bboxes (np.ndarray | list | None): Bounding boxes with shape (N, 4), in XYXY format.
+points (np.ndarray | list | None): Points indicating object locations with shape (N, 2), in pixels.
+labels (np.ndarray | list | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
 masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
 multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
 *args (Any): Additional positional arguments.
 **kwargs (Any): Additional keyword arguments.
 
 Returns:
-(
-(
-
+pred_masks (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks.
+pred_scores (torch.Tensor): An array of length C containing quality scores predicted by the model for each
+mask.
 
 Examples:
 >>> predictor = Predictor()
@@ -204,26 +203,24 @@ class Predictor(BasePredictor):
 return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
 
 def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
-"""
-Performs image segmentation inference based on input cues using SAM's specialized architecture.
+"""Perform image segmentation inference based on input cues using SAM's specialized architecture.
 
-This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
-
+This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. It
+processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
 
 Args:
 im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
-bboxes (np.ndarray |
-points (np.ndarray |
-
+bboxes (np.ndarray | list | None): Bounding boxes in XYXY format with shape (N, 4).
+points (np.ndarray | list | None): Points indicating object locations with shape (N, 2) or (N, num_points,
+2), in pixels.
+labels (np.ndarray | list | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground,
+0 for background.
 masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
 multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
 
-Raises:
-AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
 Returns:
-(
-(
+pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
+pred_scores (torch.Tensor): Quality scores predicted by the model for each mask, with length C.
 
 Examples:
 >>> predictor = Predictor()
@@ -233,7 +230,32 @@ class Predictor(BasePredictor):
 """
 features = self.get_im_features(im) if self.features is None else self.features
 
-
+prompts = self._prepare_prompts(im.shape[2:], self.batch[1][0].shape[:2], bboxes, points, labels, masks)
+return self._inference_features(features, *prompts, multimask_output)
+
+def _inference_features(
+self,
+features,
+bboxes=None,
+points=None,
+labels=None,
+masks=None,
+multimask_output=False,
+):
+"""Perform inference on image features using the SAM model.
+
+Args:
+features (torch.Tensor): Extracted image features with shape (B, C, H, W) from the SAM model image encoder.
+bboxes (np.ndarray | list[list[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
+points (np.ndarray | list[list[float]] | None): Object location points with shape (N, 2), in pixels.
+labels (np.ndarray | list[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
+masks (list[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+
+Returns:
+pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
+pred_scores (torch.Tensor): Quality scores for each mask, with length C.
+"""
 points = (points, labels) if points is not None else None
 # Embed prompts
 sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
@@ -251,28 +273,33 @@ class Predictor(BasePredictor):
 # `d` could be 1 or 3 depends on `multimask_output`.
 return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
 
-def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
-"""
-Prepares and transforms the input prompts for processing based on the destination shape.
+def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
+"""Prepare and transform the input prompts for processing based on the destination shape.
 
 Args:
-dst_shape (tuple): The target shape (height, width) for the prompts.
-
-
-
-
+dst_shape (tuple[int, int]): The target shape (height, width) for the prompts.
+src_shape (tuple[int, int]): The source shape (height, width) of the input image.
+bboxes (np.ndarray | list | None): Bounding boxes in XYXY format with shape (N, 4).
+points (np.ndarray | list | None): Points indicating object locations with shape (N, 2) or (N, num_points,
+2), in pixels.
+labels (np.ndarray | list | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground,
+0 for background.
+masks (list[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array with
+shape (H, W).
+
+Returns:
+bboxes (torch.Tensor | None): Transformed bounding boxes.
+points (torch.Tensor | None): Transformed points.
+labels (torch.Tensor | None): Transformed labels.
+masks (torch.Tensor | None): Transformed masks.
 
 Raises:
 AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
-Returns:
-(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
 """
-src_shape = self.batch[1][0].shape[:2]
 r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
 # Transform input prompts
 if points is not None:
-points = torch.as_tensor(points, dtype=
+points = torch.as_tensor(points, dtype=self.torch_dtype, device=self.device)
 points = points[None] if points.ndim == 1 else points
 # Assuming labels are all positive if users don't pass labels.
 if labels is None:
@@ -286,11 +313,15 @@ class Predictor(BasePredictor):
 # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
 points, labels = points[:, None, :], labels[:, None]
 if bboxes is not None:
-bboxes = torch.as_tensor(bboxes, dtype=
+bboxes = torch.as_tensor(bboxes, dtype=self.torch_dtype, device=self.device)
 bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
 bboxes *= r
 if masks is not None:
-masks =
+masks = np.asarray(masks, dtype=np.uint8)
+masks = masks[None] if masks.ndim == 2 else masks
+letterbox = LetterBox(dst_shape, auto=False, center=False, padding_value=0, interpolation=cv2.INTER_NEAREST)
+masks = np.stack([letterbox(image=x).squeeze() for x in masks], axis=0)
+masks = torch.tensor(masks, dtype=self.torch_dtype, device=self.device)
 return bboxes, points, labels, masks
 
 def generate(
@@ -307,18 +338,17 @@ class Predictor(BasePredictor):
 stability_score_offset=0.95,
 crop_nms_thresh=0.7,
 ):
-"""
-Perform image segmentation using the Segment Anything Model (SAM).
+"""Perform image segmentation using the Segment Anything Model (SAM).
 
-This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
-
+This method segments an entire image into constituent parts by leveraging SAM's advanced architecture and
+real-time performance capabilities. It can optionally work on image crops for finer segmentation.
 
 Args:
 im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
 crop_n_layers (int): Number of layers for additional mask predictions on image crops.
 crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
 crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
-point_grids (
+point_grids (list[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
 points_stride (int): Number of points to sample along each side of the image.
 points_batch_size (int): Batch size for the number of points processed simultaneously.
 conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
@@ -390,7 +420,7 @@ class Predictor(BasePredictor):
 pred_masks.append(crop_masks)
 pred_bboxes.append(crop_bboxes)
 pred_scores.append(crop_scores)
-region_areas.append(area.expand(
+region_areas.append(area.expand(crop_masks.shape[0]))
 
 pred_masks = torch.cat(pred_masks)
 pred_bboxes = torch.cat(pred_bboxes)
@@ -406,8 +436,7 @@ class Predictor(BasePredictor):
 return pred_masks, pred_scores, pred_bboxes
 
 def setup_model(self, model=None, verbose=True):
-"""
-Initializes the Segment Anything Model (SAM) for inference.
+"""Initialize the Segment Anything Model (SAM) for inference.
 
 This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
 parameters for image normalization and other Ultralytics compatibility settings.
@@ -424,7 +453,8 @@ class Predictor(BasePredictor):
 if model is None:
 model = self.get_model()
 model.eval()
-
+model = model.to(device)
+self.model = model.half() if self.args.half else model.float()
 self.device = device
 self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
 self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
@@ -433,33 +463,33 @@ class Predictor(BasePredictor):
 self.model.pt = False
 self.model.triton = False
 self.model.stride = 32
-self.model.fp16 =
+self.model.fp16 = self.args.half
 self.done_warmup = True
+self.torch_dtype = torch.float16 if self.model.fp16 else torch.float32
 
 def get_model(self):
-"""
+"""Retrieve or build the Segment Anything Model (SAM) for image segmentation tasks."""
 from .build import build_sam # slow import
 
 return build_sam(self.args.model)
 
 def postprocess(self, preds, img, orig_imgs):
-"""
-Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
+"""Post-process SAM's inference outputs to generate object detection masks and bounding boxes.
 
 This method scales masks and boxes to the original image size and applies a threshold to the mask
 predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
 
 Args:
-preds (
+preds (tuple): The output from SAM model inference, containing:
 - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
 - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
 - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
 img (torch.Tensor): The processed input image tensor with shape (C, H, W).
-orig_imgs (
+orig_imgs (list[np.ndarray] | torch.Tensor): The original, unprocessed images.
 
 Returns:
-
-
+(list[Results]): List of Results objects containing detection masks, bounding boxes, and other metadata for
+each processed image.
 
 Examples:
 >>> predictor = Predictor()
@@ -469,14 +499,14 @@ class Predictor(BasePredictor):
 # (N, 1, H, W), (N, 1)
 pred_masks, pred_scores = preds[:2]
 pred_bboxes = preds[2] if self.segment_all else None
-names = dict(enumerate(str(i) for i in range(
+names = dict(enumerate(str(i) for i in range(pred_masks.shape[0])))
 
 if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
 orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
 
 results = []
 for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
-if
+if masks.shape[0] == 0:
 masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
 else:
 masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
@@ -486,23 +516,24 @@ class Predictor(BasePredictor):
 else:
 pred_bboxes = batched_mask_to_box(masks)
 # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
-cls = torch.arange(
-
+cls = torch.arange(pred_masks.shape[0], dtype=torch.int32, device=pred_masks.device)
+idx = pred_scores > self.args.conf
+pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)[idx]
+masks = masks[idx]
 results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
 # Reset segment-all mode.
 self.segment_all = False
 return results
 
 def setup_source(self, source):
-"""
-Sets up the data source for inference.
+"""Set up the data source for inference.
 
-This method configures the data source from which images will be fetched for inference. It supports
-
+This method configures the data source from which images will be fetched for inference. It supports various
+input types such as image files, directories, video files, and other compatible data sources.
 
 Args:
-source (str | Path | None): The path or identifier for the image data source. Can be a file path,
-
+source (str | Path | None): The path or identifier for the image data source. Can be a file path, directory
+path, URL, or other supported source types.
 
 Examples:
 >>> predictor = Predictor()
@@ -519,16 +550,15 @@ class Predictor(BasePredictor):
 super().setup_source(source)
 
 def set_image(self, image):
-"""
-Preprocesses and sets a single image for inference.
+"""Preprocess and set a single image for inference.
 
 This method prepares the model for inference on a single image by setting up the model if not already
-initialized, configuring the data source, and preprocessing the image for feature extraction. It
-
+initialized, configuring the data source, and preprocessing the image for feature extraction. It ensures that
+only one image is set at a time and extracts image features for subsequent use.
 
 Args:
-image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
-
+image (str | np.ndarray): Path to the image file as a string, or a numpy array representing an image read by
+cv2.
 
 Raises:
 AssertionError: If more than one image is attempted to be set.
@@ -543,7 +573,7 @@ class Predictor(BasePredictor):
 - The extracted features are stored in the `self.features` attribute for later use.
 """
 if self.model is None:
-self.setup_model(
+self.setup_model()
 self.setup_source(image)
 assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
 for batch in self.dataset:
@@ -552,7 +582,7 @@ class Predictor(BasePredictor):
 break
 
 def get_im_features(self, im):
-"""
+"""Extract image features using the SAM model's image encoder for subsequent mask prediction."""
 assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
 f"SAM models only support square image size, but got {self.imgsz}."
 )
@@ -560,22 +590,21 @@ class Predictor(BasePredictor):
 return self.model.image_encoder(im)
 
 def set_prompts(self, prompts):
-"""
+"""Set prompts for subsequent inference operations."""
 self.prompts = prompts
 
 def reset_image(self):
-"""
+"""Reset the current image and its features, clearing them for subsequent inference."""
 self.im = None
 self.features = None
 
 @staticmethod
 def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
-"""
-Remove small disconnected regions and holes from segmentation masks.
+"""Remove small disconnected regions and holes from segmentation masks.
 
-This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
-
-
+This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). It
+removes small disconnected regions and holes from the input masks, and then performs Non-Maximum Suppression
+(NMS) to eliminate any newly created duplicate boxes.
 
 Args:
 masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
@@ -586,7 +615,7 @@ class Predictor(BasePredictor):
 
 Returns:
 new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
-keep (
+keep (list[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
 
 Examples:
 >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
@@ -596,7 +625,7 @@ class Predictor(BasePredictor):
 """
 import torchvision # scope for faster 'import ultralytics'
 
-if
+if masks.shape[0] == 0:
 return masks
 
 # Filter small disconnected regions and holes
@@ -620,28 +649,74 @@ class Predictor(BasePredictor):
 
 return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
 
+@smart_inference_mode()
+def inference_features(
+self,
+features,
+src_shape,
+dst_shape=None,
+bboxes=None,
+points=None,
+labels=None,
+masks=None,
+multimask_output=False,
+):
+"""Perform prompts preprocessing and inference on provided image features using the SAM model.
+
+Args:
+features (torch.Tensor | dict[str, Any]): Extracted image features from the SAM/SAM2 model image encoder.
+src_shape (tuple[int, int]): The source shape (height, width) of the input image.
+dst_shape (tuple[int, int] | None): The target shape (height, width) for the prompts. If None, defaults to
+(imgsz, imgsz).
+bboxes (np.ndarray | list[list[float]] | None): Bounding boxes in xyxy format with shape (N, 4).
+points (np.ndarray | list[list[float]] | None): Points indicating object locations with shape (N, 2), in
+pixels.
+labels (np.ndarray | list[int] | None): Point prompt labels with shape (N, ).
+masks (list[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+
+Returns:
+pred_masks (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks.
+pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 6), where N is the number of boxes.
+Each box is in xyxy format with additional columns for score and class.
+
+Notes:
+- The input features is a torch.Tensor of shape (B, C, H, W) if performing on SAM, or a dict[str, Any] if performing on SAM2.
+"""
+dst_shape = dst_shape or (self.args.imgsz, self.args.imgsz)
+prompts = self._prepare_prompts(dst_shape, src_shape, bboxes, points, labels, masks)
+pred_masks, pred_scores = self._inference_features(features, *prompts, multimask_output)
+if pred_masks.shape[0] == 0:
+pred_masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
+else:
+pred_masks = ops.scale_masks(pred_masks[None].float(), src_shape, padding=False)[0]
+pred_masks = pred_masks > self.model.mask_threshold # to bool
+pred_bboxes = batched_mask_to_box(pred_masks)
+# NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
+cls = torch.arange(pred_masks.shape[0], dtype=torch.int32, device=pred_masks.device)
+pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
+return pred_masks, pred_bboxes
+
 
 class SAM2Predictor(Predictor):
-"""
-SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
+"""SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
 
-This class extends the base Predictor class to implement SAM2-specific functionality for image
-
-prompt-based inference.
+This class extends the base Predictor class to implement SAM2-specific functionality for image segmentation tasks.
+It provides methods for model initialization, feature extraction, and prompt-based inference.
 
 Attributes:
-_bb_feat_sizes (
+_bb_feat_sizes (list[tuple]): Feature sizes for different backbone levels.
 model (torch.nn.Module): The loaded SAM2 model.
 device (torch.device): The device (CPU or GPU) on which the model is loaded.
-features (
+features (dict): Cached image features for efficient inference.
 segment_all (bool): Flag to indicate if all segments should be predicted.
-prompts (dict): Dictionary to store various types of prompts for inference.
+prompts (dict[str, Any]): Dictionary to store various types of prompts for inference.
 
 Methods:
-get_model:
-prompt_inference:
-set_image:
-get_im_features:
+get_model: Retrieve and initialize the SAM2 model.
+prompt_inference: Perform image segmentation inference based on various prompts.
+set_image: Preprocess and set a single image for inference.
+get_im_features: Extract and process image features using SAM2's image encoder.
 
 Examples:
 >>> predictor = SAM2Predictor(cfg)
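The hunk above adds a public `inference_features` entry point that reuses cached encoder features instead of re-running the image encoder for every prompt. A minimal sketch of how that path might be driven, assuming `Predictor` is importable from `ultralytics.models.sam`; the weights file and image path below are placeholders, not values taken from this diff:

```python
# Sketch only: exercise the feature-level API added in 8.3.224.
# "sam_b.pt" and "image.jpg" are placeholder names.
import cv2
from ultralytics.models.sam import Predictor

predictor = Predictor(overrides=dict(model="sam_b.pt", imgsz=1024))
image = cv2.imread("image.jpg")

predictor.set_image(image)  # runs the image encoder once and caches self.features
masks, boxes = predictor.inference_features(
    predictor.features,             # reuse the cached encoder features
    src_shape=image.shape[:2],      # original (H, W) so outputs are scaled back correctly
    bboxes=[[100, 100, 200, 200]],  # XYXY box prompt, as documented in the hunk above
)
# boxes has shape (N, 6): XYXY coordinates, score, and a placeholder class index
```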
@@ -658,100 +733,36 @@ class SAM2Predictor(Predictor):
 ]
 
 def get_model(self):
-"""
+"""Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
 from .build import build_sam # slow import
 
 return build_sam(self.args.model)
 
-def
-
-im,
-bboxes=None,
-points=None,
-labels=None,
-masks=None,
-multimask_output=False,
-img_idx=-1,
-):
-"""
-Performs image segmentation inference based on various prompts using SAM2 architecture.
-
-This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
-based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
-multi-object prediction scenarios.
+def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
+"""Prepare and transform the input prompts for processing based on the destination shape.
 
 Args:
-
-
-
-
-
-
-
+dst_shape (tuple[int, int]): The target shape (height, width) for the prompts.
+src_shape (tuple[int, int]): The source shape (height, width) of the input image.
+bboxes (np.ndarray | list | None): Bounding boxes in XYXY format with shape (N, 4).
+points (np.ndarray | list | None): Points indicating object locations with shape (N, 2) or (N, num_points,
+2), in pixels.
+labels (np.ndarray | list | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground,
+0 for background.
+masks (list | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
 
 Returns:
-(
-(
-
-Examples:
->>> predictor = SAM2Predictor(cfg)
->>> image = torch.rand(1, 3, 640, 640)
->>> bboxes = [[100, 100, 200, 200]]
->>> result = predictor(image, bboxes=bboxes)[0]
->>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")
-
-Notes:
-- The method supports batched inference for multiple objects when points or bboxes are provided.
-- Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
-- When both bboxes and points are provided, they are merged into a single 'points' input for the model.
-"""
-features = self.get_im_features(im) if self.features is None else self.features
-
-points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
-points = (points, labels) if points is not None else None
-
-sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
-points=points,
-boxes=None,
-masks=masks,
-)
-# Predict masks
-batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
-high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
-pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
-image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
-image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
-sparse_prompt_embeddings=sparse_embeddings,
-dense_prompt_embeddings=dense_embeddings,
-multimask_output=multimask_output,
-repeat_image=batched_mode,
-high_res_features=high_res_features,
-)
-# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
-# `d` could be 1 or 3 depends on `multimask_output`.
-return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
-
-def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
-"""
-Prepares and transforms the input prompts for processing based on the destination shape.
-
-Args:
-dst_shape (tuple): The target shape (height, width) for the prompts.
-bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
-points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
-labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
-masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
+points (torch.Tensor | None): Transformed points.
+labels (torch.Tensor | None): Transformed labels.
+masks (torch.Tensor | None): Transformed masks.
 
 Raises:
 AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
-Returns:
-(tuple): A tuple containing transformed points, labels, and masks.
 """
-bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
+bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, src_shape, bboxes, points, labels, masks)
 if bboxes is not None:
 bboxes = bboxes.view(-1, 2, 2)
-bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(
+bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(bboxes.shape[0], -1)
 # NOTE: merge "boxes" and "points" into a single "points" input
 # (where boxes are added at the beginning) to model.sam_prompt_encoder
 if points is not None:
@@ -762,11 +773,10 @@ class SAM2Predictor(Predictor):
 return points, labels, masks
 
 def set_image(self, image):
-"""
-Preprocesses and sets a single image for inference using the SAM2 model.
+"""Preprocess and set a single image for inference using the SAM2 model.
 
-This method initializes the model if not already done, configures the data source to the specified image,
-
+This method initializes the model if not already done, configures the data source to the specified image, and
+preprocesses the image for feature extraction. It supports setting only one image at a time.
 
 Args:
 image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
@@ -794,7 +804,7 @@ class SAM2Predictor(Predictor):
 break
 
 def get_im_features(self, im):
-"""
+"""Extract image features from the SAM image encoder for subsequent processing."""
 assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
 f"SAM 2 models only support square image size, but got {self.imgsz}."
 )
@@ -806,50 +816,108 @@ class SAM2Predictor(Predictor):
|
|
|
806
816
|
if self.model.directly_add_no_mem_embed:
|
|
807
817
|
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
|
808
818
|
feats = [
|
|
809
|
-
feat.permute(1, 2, 0).view(1, -1, *feat_size)
|
|
810
|
-
|
|
811
|
-
][::-1]
|
|
819
|
+
feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats, self._bb_feat_sizes)
|
|
820
|
+
]
|
|
812
821
|
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
|
813
822
|
|
|
823
|
+
def _inference_features(
|
|
824
|
+
self,
|
|
825
|
+
features,
|
|
826
|
+
points=None,
|
|
827
|
+
labels=None,
|
|
828
|
+
masks=None,
|
|
829
|
+
multimask_output=False,
|
|
830
|
+
img_idx=-1,
|
|
831
|
+
):
|
|
832
|
+
"""Perform inference on image features using the SAM2 model.
|
|
833
|
+
|
|
834
|
+
Args:
|
|
835
|
+
features (torch.Tensor | dict[str, Any]): Extracted image features with shape (B, C, H, W) from the SAM2
|
|
836
|
+
model image encoder, it could also be a dictionary including:
|
|
837
|
+
- image_embed (torch.Tensor): Image embedding with shape (B, C, H, W).
|
|
838
|
+
- high_res_feats (list[torch.Tensor]): List of high-resolution feature maps from the backbone, each with shape (B, C, H, W).
|
|
839
|
+
points (np.ndarray | list[list[float]] | None): Object location points with shape (N, 2), in pixels.
+ labels (np.ndarray | list[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
+ masks (list[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+ img_idx (int): Index of the image in the batch to process.
+
+ Returns:
+ pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
+ pred_scores (torch.Tensor): Quality scores for each mask, with length C.
+ """
+ points = (points, labels) if points is not None else None
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+ points=points,
+ boxes=None,
+ masks=masks,
+ )
+ # Predict masks
+ batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
+ high_res_features = None
+ if isinstance(features, dict):
+ high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+ features = features["image_embed"][[img_idx]]
+ pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
+ image_embeddings=features,
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=batched_mode,
+ high_res_features=high_res_features,
+ )
+ # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+ # `d` could be 1 or 3 depends on `multimask_output`.
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+

  class SAM2VideoPredictor(SAM2Predictor):
- """
- SAM2VideoPredictor to handle user interactions with videos and manage inference states.
+ """SAM2VideoPredictor to handle user interactions with videos and manage inference states.

- This class extends the functionality of SAM2Predictor to support video processing and maintains
-
-
+ This class extends the functionality of SAM2Predictor to support video processing and maintains the state of
+ inference operations. It includes configurations for managing non-overlapping masks, clearing memory for
+ non-conditional inputs, and setting up callbacks for prediction events.

  Attributes:
  inference_state (dict): A dictionary to store the current state of inference operations.
  non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
  clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
- clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object
+ clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object
+ scenarios.
  callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.

-
-
-
-
+ Methods:
+ get_model: Retrieve and configure the model with binarization enabled.
+ inference: Perform image segmentation inference based on the given input cues.
+ postprocess: Post-process the predictions to apply non-overlapping constraints if required.
+ add_new_prompts: Add new points or masks to a specific frame for a given object ID.
+ propagate_in_video_preflight: Prepare inference_state and consolidate temporary outputs before tracking.
+ init_state: Initialize an inference state for the predictor.
+ get_im_features: Extract image features using SAM2's image encoder for subsequent segmentation tasks.
+
+ Examples:
+ >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
+ >>> predictor.set_image("path/to/video_frame.jpg")
+ >>> bboxes = [[100, 100, 200, 200]]
+ >>> results = predictor(bboxes=bboxes)

-
+ Notes:
  The `fill_hole_area` attribute is defined but not used in the current implementation.
  """

  # fill_hole_area = 8 # not used

  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
- """
- Initialize the predictor with configuration and optional overrides.
+ """Initialize the predictor with configuration and optional overrides.

- This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
-
- that control the behavior of the predictor.
+ This constructor initializes the SAM2VideoPredictor with a given configuration, applies any specified overrides,
+ and sets up the inference state along with certain flags that control the behavior of the predictor.

  Args:
  cfg (dict): Configuration dictionary containing default settings.
- overrides (
- _callbacks (
+ overrides (dict | None): Dictionary of values to override default configuration.
+ _callbacks (dict | None): Dictionary of callback functions to customize behavior.

  Examples:
  >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
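The hunk above rewrites the `SAM2VideoPredictor` class docstring, adding a Methods summary and an Examples block. For orientation only (this is not part of the diff), a minimal usage sketch based on that Examples block might look like the following; the import path, weight file name, and video path are assumptions, not guaranteed by the diff:

```python
# Minimal sketch, assuming the import location and file names below; adapt to your install.
from ultralytics.models.sam import SAM2VideoPredictor  # assumed import path

overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2VideoPredictor(overrides=overrides)

# Prompt the first frame with a box; the predictor then propagates masks through the video.
results = predictor(source="path/to/video.mp4", bboxes=[[100, 100, 200, 200]])
for r in results:
    print(r.path, 0 if r.masks is None else len(r.masks))
```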
@@ -864,10 +932,9 @@ class SAM2VideoPredictor(SAM2Predictor):
  self.callbacks["on_predict_start"].append(self.init_state)

  def get_model(self):
- """
- Retrieves and configures the model with binarization enabled.
+ """Retrieve and configure the model with binarization enabled.

-
+ Notes:
  This method overrides the base class implementation to set the binarize flag to True.
  """
  model = super().get_model()
@@ -875,21 +942,20 @@ class SAM2VideoPredictor(SAM2Predictor):
  return model

  def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
- """
-
-
- mask decoder for real-time and promptable segmentation tasks.
+ """Perform image segmentation inference based on the given input cues, using the currently loaded image. This
+ method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
+ encoder, and mask decoder for real-time and promptable segmentation tasks.

  Args:
  im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
- bboxes (np.ndarray |
- points (np.ndarray |
- labels (np.ndarray |
+ bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
+ points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
+ labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
  masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.

  Returns:
- (
- (
+ pred_masks (torch.Tensor): The output masks in shape CxHxW, where C is the number of generated masks.
+ pred_scores (torch.Tensor): An array of length C containing predicted quality scores for each mask.
  """
  # Override prompts if any stored in self.prompts
  bboxes = self.prompts.pop("bboxes", bboxes)
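The updated Args block above spells out the prompt formats `inference` expects. Purely as an illustration (not part of the diff, and continuing the earlier predictor sketch), prompts with those shapes could be assembled like this; variable names and the frame path are arbitrary:

```python
import numpy as np

# Shapes follow the docstring above: boxes are (N, 4) XYXY, points are (N, 2) in pixels,
# labels are (N,) with 1 = foreground and 0 = background.
bboxes = np.array([[100, 100, 200, 200]], dtype=np.float32)  # one box prompt
points = np.array([[150, 150], [40, 60]], dtype=np.float32)  # two point prompts
labels = np.array([1, 0], dtype=np.int32)                    # foreground, background

# `predictor` is the SAM2VideoPredictor from the earlier sketch; call pattern assumed.
results = predictor(source="frame_0001.jpg", points=points, labels=labels)
```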
@@ -900,7 +966,9 @@ class SAM2VideoPredictor(SAM2Predictor):
  self.inference_state["im"] = im
  output_dict = self.inference_state["output_dict"]
  if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts
- points, labels, masks = self._prepare_prompts(
+ points, labels, masks = self._prepare_prompts(
+ im.shape[2:], self.batch[1][0].shape[:2], bboxes, points, labels, masks
+ )
  if points is not None:
  for i in range(len(points)):
  self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
@@ -943,25 +1011,24 @@ class SAM2VideoPredictor(SAM2Predictor):
  pred_masks = current_out["pred_masks"].flatten(0, 1)
  pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks

- return pred_masks, torch.ones(
+ return pred_masks, torch.ones(pred_masks.shape[0], dtype=pred_masks.dtype, device=pred_masks.device)

  def postprocess(self, preds, img, orig_imgs):
- """
- Post-processes the predictions to apply non-overlapping constraints if required.
+ """Post-process the predictions to apply non-overlapping constraints if required.

- This method extends the post-processing functionality by applying non-overlapping constraints
-
-
+ This method extends the post-processing functionality by applying non-overlapping constraints to the predicted
+ masks if the `non_overlap_masks` flag is set to True. This ensures that the masks do not overlap, which can be
+ useful for certain applications.

  Args:
- preds (
+ preds (tuple[torch.Tensor, torch.Tensor]): The predicted masks and scores from the model.
  img (torch.Tensor): The processed image tensor.
- orig_imgs (
+ orig_imgs (list[np.ndarray]): The original images before processing.

  Returns:
-
+ (list): The post-processed predictions.

-
+ Notes:
  If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
  """
  results = super().postprocess(preds, img, orig_imgs)
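The `postprocess` docstring above refers to a non-overlapping constraint on the predicted masks. As a rough illustration of the idea only (this is not the library's actual implementation), overlapping per-object mask logits can be resolved by keeping, at each pixel, only the highest-scoring object:

```python
import torch


def apply_non_overlapping_constraint(mask_logits: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch: suppress all but the highest-scoring object at each pixel.

    mask_logits: (N, 1, H, W) per-object logits. Returns logits where non-winning
    objects are pushed to a very negative value, so the binarized masks cannot overlap.
    """
    if mask_logits.shape[0] == 1:
        return mask_logits
    winners = torch.argmax(mask_logits, dim=0, keepdim=True)  # (1, 1, H, W)
    obj_index = torch.arange(mask_logits.shape[0], device=mask_logits.device)[:, None, None, None]
    keep = winners == obj_index  # (N, 1, H, W) boolean, True only for the winning object
    # Losers are clamped to a strongly negative logit so they binarize to background.
    return torch.where(keep, mask_logits, torch.clamp(mask_logits, max=-10.0))
```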
@@ -981,28 +1048,28 @@ class SAM2VideoPredictor(SAM2Predictor):
  masks=None,
  frame_idx=0,
  ):
- """
- Adds new points or masks to a specific frame for a given object ID.
+ """Add new points or masks to a specific frame for a given object ID.

- This method updates the inference state with new prompts (points or masks) for a specified
-
-
-
+ This method updates the inference state with new prompts (points or masks) for a specified object and frame
+ index. It ensures that the prompts are either points or masks, but not both, and updates the internal state
+ accordingly. It also handles the generation of new segmentations based on the provided prompts and the existing
+ state.

  Args:
  obj_id (int): The ID of the object to which the prompts are associated.
- points (torch.Tensor,
- labels (torch.Tensor,
- masks (torch.Tensor, optional): Binary masks for the object.
- frame_idx (int, optional): The index of the frame to which the prompts are applied.
+ points (torch.Tensor, optional): The coordinates of the points of interest.
+ labels (torch.Tensor, optional): The labels corresponding to the points.
+ masks (torch.Tensor, optional): Binary masks for the object.
+ frame_idx (int, optional): The index of the frame to which the prompts are applied.

  Returns:
- (
+ pred_masks (torch.Tensor): The flattened predicted masks.
+ pred_scores (torch.Tensor): A tensor of ones indicating the number of objects.

  Raises:
  AssertionError: If both `masks` and `points` are provided, or neither is provided.

-
+ Notes:
  - Only one type of prompt (either points or masks) can be added per call.
  - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
  - The method handles the consolidation of outputs and resizing of masks to the original video resolution.
@@ -1043,7 +1110,9 @@ class SAM2VideoPredictor(SAM2Predictor):
  )

  if prev_out is not None and prev_out.get("pred_masks") is not None:
- prev_sam_mask_logits = prev_out["pred_masks"].to(
+ prev_sam_mask_logits = prev_out["pred_masks"].to(
+ device=self.device, non_blocking=self.device.type == "cuda"
+ )
  # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
  prev_sam_mask_logits.clamp_(-32.0, 32.0)
  current_out = self._run_single_frame_inference(
@@ -1075,13 +1144,12 @@ class SAM2VideoPredictor(SAM2Predictor):

  @smart_inference_mode()
  def propagate_in_video_preflight(self):
- """
- Prepare inference_state and consolidate temporary outputs before tracking.
+ """Prepare inference_state and consolidate temporary outputs before tracking.

- This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
-
-
-
+ This method marks the start of tracking, disallowing the addition of new objects until the session is reset. It
+ consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. Additionally,
+ it clears non-conditioning memory around input frames and ensures that the state is consistent with the provided
+ inputs.
  """
  # Tracking has started and we don't allow adding new objects until session is reset.
  self.inference_state["tracking_has_started"] = True
@@ -1146,12 +1214,11 @@ class SAM2VideoPredictor(SAM2Predictor):

  @staticmethod
  def init_state(predictor):
- """
- Initialize an inference state for the predictor.
+ """Initialize an inference state for the predictor.

- This function sets up the initial state required for performing inference on video data.
-
-
+ This function sets up the initial state required for performing inference on video data. It includes
+ initializing various dictionaries and ordered dictionaries that will store inputs, outputs, and other metadata
+ relevant to the tracking process.

  Args:
  predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
@@ -1193,22 +1260,22 @@ class SAM2VideoPredictor(SAM2Predictor):
  predictor.inference_state = inference_state

  def get_im_features(self, im, batch=1):
- """
- Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
+ """Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.

  Args:
  im (torch.Tensor): The input image tensor.
- batch (int, optional): The batch size for expanding features if there are multiple prompts.
+ batch (int, optional): The batch size for expanding features if there are multiple prompts.

  Returns:
  vis_feats (torch.Tensor): The visual features extracted from the image.
  vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
- feat_sizes (
+ feat_sizes (list[tuple]): A list containing the sizes of the extracted features.

-
+ Notes:
  - If `batch` is greater than 1, the features are expanded to fit the batch size.
  - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
  """
+ self.model.set_imgsz(self.imgsz)
  backbone_out = self.model.forward_image(im)
  if batch > 1: # expand features if there's more than one prompt
  for i, feat in enumerate(backbone_out["backbone_fpn"]):
@@ -1220,19 +1287,18 @@ class SAM2VideoPredictor(SAM2Predictor):
  return vis_feats, vis_pos_embed, feat_sizes

  def _obj_id_to_idx(self, obj_id):
- """
- Map client-side object id to model-side object index.
+ """Map client-side object id to model-side object index.

  Args:
  obj_id (int): The unique identifier of the object provided by the client side.

  Returns:
-
+ (int): The index of the object on the model side.

  Raises:
  RuntimeError: If an attempt is made to add a new object after tracking has started.

-
+ Notes:
  - The method updates or retrieves mappings between object IDs and indices stored in
  `inference_state`.
  - It ensures that new objects can only be added before tracking commences.
@@ -1283,27 +1349,26 @@ class SAM2VideoPredictor(SAM2Predictor):
  run_mem_encoder,
  prev_sam_mask_logits=None,
  ):
- """
- Run tracking on a single frame based on current inputs and previous memory.
+ """Run tracking on a single frame based on current inputs and previous memory.

  Args:
  output_dict (dict): The dictionary containing the output states of the tracking process.
  frame_idx (int): The index of the current frame.
  batch_size (int): The batch size for processing the frame.
  is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
- point_inputs (dict
- mask_inputs (torch.Tensor
+ point_inputs (dict | None): Input points and their labels.
+ mask_inputs (torch.Tensor | None): Input binary masks.
  reverse (bool): Indicates if the tracking should be performed in reverse order.
  run_mem_encoder (bool): Indicates if the memory encoder should be executed.
- prev_sam_mask_logits (torch.Tensor
+ prev_sam_mask_logits (torch.Tensor | None): Previous mask logits for the current object.

  Returns:
-
+ (dict): A dictionary containing the output of the tracking step, including updated features and predictions.

  Raises:
  AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.

-
+ Notes:
  - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
  - The method retrieves image features using the `get_im_features` method.
  - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
@@ -1334,12 +1399,12 @@ class SAM2VideoPredictor(SAM2Predictor):
  maskmem_features = current_out["maskmem_features"]
  if maskmem_features is not None:
  current_out["maskmem_features"] = maskmem_features.to(
- dtype=torch.float16, device=self.device, non_blocking=
+ dtype=torch.float16, device=self.device, non_blocking=self.device.type == "cuda"
  )
  # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
  # potentially fill holes in the predicted masks
  # if self.fill_hole_area > 0:
- # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=
+ # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=self.device.type == "cuda")
  # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)

  # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
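Several of the changes above replace a hard-coded `non_blocking` argument with `non_blocking=self.device.type == "cuda"`. The idea is that asynchronous host-to-device copies only help on CUDA devices; elsewhere the flag is pointless. A small standalone sketch of the same pattern (independent of the diff) is:

```python
import torch


def to_device(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Request an asynchronous copy only when the target is a CUDA device;
    # on CPU or MPS the flag has no benefit, so it is disabled.
    return t.to(device=device, non_blocking=device.type == "cuda")


x = torch.zeros(4, 256, 256)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
y = to_device(x, device)
print(y.device, y.shape)
```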
@@ -1347,24 +1412,22 @@ class SAM2VideoPredictor(SAM2Predictor):
  return current_out

  def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
- """
- Caches and manages the positional encoding for mask memory across frames and objects.
+ """Cache and manage the positional encoding for mask memory across frames and objects.

- This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
-
-
- encoding
-
- the current batch size.
+ This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for mask memory, which is
+ constant across frames and objects, thus reducing the amount of redundant information stored during an inference
+ session. It checks if the positional encoding has already been cached; if not, it caches a slice of the provided
+ encoding. If the batch size is greater than one, it expands the cached positional encoding to match the current
+ batch size.

  Args:
- out_maskmem_pos_enc (
-
+ out_maskmem_pos_enc (list[torch.Tensor] | None): The positional encoding for mask memory. Should be a list
+ of tensors or None.

  Returns:
-
+ (list[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.

-
+ Notes:
  - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
  - Only a single object's slice is cached since the encoding is the same across objects.
  - The method checks if the positional encoding has already been cached in the session's constants.
@@ -1381,7 +1444,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  else:
  maskmem_pos_enc = model_constants["maskmem_pos_enc"]
  # expand the cached maskmem_pos_enc to the actual batch size
- batch_size = out_maskmem_pos_enc[0].
+ batch_size = out_maskmem_pos_enc[0].shape[0]
  if batch_size > 1:
  out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
  return out_maskmem_pos_enc
@@ -1392,25 +1455,23 @@ class SAM2VideoPredictor(SAM2Predictor):
  is_cond=False,
  run_mem_encoder=False,
  ):
- """
- Consolidates per-object temporary outputs into a single output for all objects.
+ """Consolidate per-object temporary outputs into a single output for all objects.

  This method combines the temporary outputs for each object on a given frame into a unified
  output. It fills in any missing objects either from the main output dictionary or leaves
- placeholders if they do not exist in the main output. Optionally, it can re-run the memory
-
+ placeholders if they do not exist in the main output. Optionally, it can re-run the memory encoder after
+ applying non-overlapping constraints to the object scores.

  Args:
  frame_idx (int): The index of the frame for which to consolidate outputs.
- is_cond (bool,
-
-
- consolidating the outputs. Defaults to False.
+ is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.
+ run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after consolidating the
+ outputs.

  Returns:
-
+ (dict): A consolidated output dictionary containing the combined results for all objects.

-
+ Notes:
  - The method initializes the consolidated output with placeholder values for missing objects.
  - It searches for outputs in both the temporary and main output dictionaries.
  - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
@@ -1429,13 +1490,13 @@ class SAM2VideoPredictor(SAM2Predictor):
  "pred_masks": torch.full(
  size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
  fill_value=-1024.0,
- dtype=
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  "obj_ptr": torch.full(
  size=(batch_size, self.model.hidden_dim),
  fill_value=-1024.0,
- dtype=
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  "object_score_logits": torch.full(
@@ -1443,7 +1504,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  # default to 10.0 for object_score_logits, i.e. assuming the object is
  # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
  fill_value=10.0,
- dtype=
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  }
@@ -1494,8 +1555,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  return consolidated_out

  def _get_empty_mask_ptr(self, frame_idx):
- """
- Get a dummy object pointer based on an empty mask on the current frame.
+ """Get a dummy object pointer based on an empty mask on the current frame.

  Args:
  frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
@@ -1515,7 +1575,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  feat_sizes=feat_sizes,
  point_inputs=None,
  # A dummy (empty) mask with a single object
- mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=
+ mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=self.torch_dtype, device=self.device),
  output_dict={},
  num_frames=self.inference_state["num_frames"],
  track_in_reverse=False,
@@ -1525,8 +1585,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  return current_out["obj_ptr"]

  def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
- """
- Run the memory encoder on masks.
+ """Run the memory encoder on masks.

  This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
  memory also needs to be computed again with the memory encoder.
@@ -1538,7 +1597,8 @@ class SAM2VideoPredictor(SAM2Predictor):
  is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.

  Returns:
- (
+ maskmem_features (torch.Tensor): The encoded mask features.
+ maskmem_pos_enc (torch.Tensor): The positional encoding.
  """
  # Retrieve correct image features
  current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
@@ -1552,11 +1612,12 @@ class SAM2VideoPredictor(SAM2Predictor):

  # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
  maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
- return maskmem_features.to(
+ return maskmem_features.to(
+ dtype=torch.float16, device=self.device, non_blocking=self.device.type == "cuda"
+ ), maskmem_pos_enc

  def _add_output_per_object(self, frame_idx, current_out, storage_key):
- """
- Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
+ """Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.

  The resulting slices share the same tensor storage.

@@ -1586,12 +1647,12 @@ class SAM2VideoPredictor(SAM2Predictor):
  obj_output_dict[storage_key][frame_idx] = obj_out

  def _clear_non_cond_mem_around_input(self, frame_idx):
- """
- Remove the non-conditioning memory around the input frame.
+ """Remove the non-conditioning memory around the input frame.

- When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain
- object appearance information and could confuse the model. This method clears those non-conditioning
- surrounding the interacted frame to avoid giving the model both old and new information about the
+ When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain
+ outdated object appearance information and could confuse the model. This method clears those non-conditioning
+ memories surrounding the interacted frame to avoid giving the model both old and new information about the
+ object.

  Args:
  frame_idx (int): The index of the current frame where user interaction occurred.
@@ -1603,3 +1664,346 @@ class SAM2VideoPredictor(SAM2Predictor):
  self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
  for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
  obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+
+
+ class SAM2DynamicInteractivePredictor(SAM2Predictor):
+ """SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
+ sequence of images.
+
+ Attributes:
+ memory_bank (list): OrderedDict: Stores the states of each image with prompts.
+ obj_idx_set (set): A set to keep track of the object indices that have been added.
+ obj_id_to_idx (OrderedDict): Maps object IDs to their corresponding indices.
+ obj_idx_to_id (OrderedDict): Maps object indices to their corresponding IDs.
+
+ Methods:
+ get_model: Retrieves and configures the model with binarization enabled.
+ inference: Performs inference on a single image with optional prompts and object IDs.
+ postprocess: Post-processes the predictions to apply non-overlapping constraints if required.
+ update_memory: Append the imgState to the memory_bank and update the memory for the model.
+ track_step: Tracking step for the current image state to predict masks.
+ get_maskmem_enc: Get memory and positional encoding from the memory bank.
+
+ Examples:
+ >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
+ >>> predictor(source=support_img1, bboxes=bboxes1, obj_ids=labels1, update_memory=True)
+ >>> results1 = predictor(source=query_img1)
+ >>> predictor(source=support_img2, bboxes=bboxes2, obj_ids=labels2, update_memory=True)
+ >>> results2 = predictor(source=query_img2)
+ """
+
+ def __init__(
+ self,
+ cfg: Any = DEFAULT_CFG,
+ overrides: dict[str, Any] | None = None,
+ max_obj_num: int = 3,
+ _callbacks: dict[str, Any] | None = None,
+ ) -> None:
+ """Initialize the predictor with configuration and optional overrides.
+
+ This constructor initializes the SAM2DynamicInteractivePredictor with a given configuration, applies any
+ specified overrides
+
+ Args:
+ cfg (dict[str, Any]): Configuration dictionary containing default settings.
+ overrides (dict[str, Any] | None): Dictionary of values to override default configuration.
+ max_obj_num (int): Maximum number of objects to track. Default is 3. this is set to keep fix feature size
+ for the model.
+ _callbacks (dict[str, Any] | None): Dictionary of callback functions to customize behavior.
+
+ Examples:
+ >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
+ >>> predictor_example_with_imgsz = SAM2DynamicInteractivePredictor(overrides={"imgsz": 640})
+ >>> predictor_example_with_callback = SAM2DynamicInteractivePredictor(
+ ... _callbacks={"on_predict_start": custom_callback}
+ ... )
+ """
+ super().__init__(cfg, overrides, _callbacks)
+ self.non_overlap_masks = True
+
+ # Initialize the memory bank to store image states
+ # NOTE: probably need to use dict for better query
+ self.memory_bank = []
+
+ # Initialize the object index set and mappings
+ self.obj_idx_set = set()
+ self.obj_id_to_idx = OrderedDict()
+ self.obj_idx_to_id = OrderedDict()
+ self._max_obj_num = max_obj_num
+ for i in range(self._max_obj_num):
+ self.obj_id_to_idx[i + 1] = i
+ self.obj_idx_to_id[i] = i + 1
+
+ @smart_inference_mode()
+ def inference(
+ self,
+ im: torch.Tensor | np.ndarray,
+ bboxes: list[list[float]] | None = None,
+ masks: torch.Tensor | np.ndarray | None = None,
+ points: list[list[float]] | None = None,
+ labels: list[int] | None = None,
+ obj_ids: list[int] | None = None,
+ update_memory: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Perform inference on a single image with optional bounding boxes, masks, points and object IDs. It has two
+ modes: one is to run inference on a single image without updating the memory, and the other is to update
+ the memory with the provided prompts and object IDs. When update_memory is True, it will update the
+ memory with the provided prompts and obj_ids. When update_memory is False, it will only run inference on
+ the provided image without updating the memory.
+
+ Args:
+ im (torch.Tensor | np.ndarray): The input image tensor or numpy array.
+ bboxes (list[list[float]] | None): Optional list of bounding boxes to update the memory.
+ masks (list[torch.Tensor | np.ndarray] | None): Optional masks to update the memory.
+ points (list[list[float]] | None): Optional list of points to update the memory, each point is [x, y].
+ labels (list[int] | None): Optional list of object IDs corresponding to the points (>0 for positive, 0 for
+ negative).
+ obj_ids (list[int] | None): Optional list of object IDs corresponding to the prompts.
+ update_memory (bool): Flag to indicate whether to update the memory with new objects.
+
+ Returns:
+ res_masks (torch.Tensor): The output masks in shape (C, H, W)
+ object_score_logits (torch.Tensor): Quality scores for each mask
+ """
+ self.get_im_features(im)
+ points, labels, masks = self._prepare_prompts(
+ dst_shape=self.imgsz,
+ src_shape=self.batch[1][0].shape[:2],
+ points=points,
+ bboxes=bboxes,
+ labels=labels,
+ masks=masks,
+ )
+
+ if update_memory:
+ if isinstance(obj_ids, int):
+ obj_ids = [obj_ids]
+ assert obj_ids is not None, "obj_ids must be provided when update_memory is True"
+ assert masks is not None or points is not None, (
+ "bboxes, masks, or points must be provided when update_memory is True"
+ )
+ if points is None: # placeholder
+ points = torch.zeros((len(obj_ids), 0, 2), dtype=self.torch_dtype, device=self.device)
+ labels = torch.zeros((len(obj_ids), 0), dtype=torch.int32, device=self.device)
+ if masks is not None:
+ assert len(masks) == len(obj_ids), "masks and obj_ids must have the same length."
+ assert len(points) == len(obj_ids), "points and obj_ids must have the same length."
+ self.update_memory(obj_ids, points, labels, masks)
+
+ current_out = self.track_step()
+ pred_masks, pred_scores = current_out["pred_masks"], current_out["object_score_logits"]
+ # filter the masks and logits based on the object indices
+ if len(self.obj_idx_set) == 0:
+ raise RuntimeError("No objects have been added to the state. Please add objects before inference.")
+ idx = list(self.obj_idx_set) # cls id
+ pred_masks, pred_scores = pred_masks[idx], pred_scores[idx]
+ # the original score are in [-32,32], and a object score larger than 0 means the object is present, we map it to [-1,1] range,
+ # and use a activate function to make sure the object score logits are non-negative, so that we can use it as a mask
+ pred_scores = torch.clamp_(pred_scores / 32, min=0)
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+ def get_im_features(self, img: torch.Tensor | np.ndarray) -> None:
+ """Initialize the image state by processing the input image and extracting features.
+
+ Args:
+ img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
+ """
+ vis_feats, vis_pos_embed, feat_sizes = SAM2VideoPredictor.get_im_features(self, img, batch=self._max_obj_num)
+ self.high_res_features = [
+ feat.permute(1, 2, 0).view(*feat.shape[1:], *feat_size)
+ for feat, feat_size in zip(vis_feats[:-1], feat_sizes[:-1])
+ ]
+
+ self.vision_feats = vis_feats
+ self.vision_pos_embeds = vis_pos_embed
+ self.feat_sizes = feat_sizes
+
+ @smart_inference_mode()
+ def update_memory(
+ self,
+ obj_ids: list[int] | None = None,
+ points: torch.Tensor | None = None,
+ labels: torch.Tensor | None = None,
+ masks: torch.Tensor | None = None,
+ ) -> None:
+ """Append the imgState to the memory_bank and update the memory for the model.
+
+ Args:
+ obj_ids (list[int]): List of object IDs corresponding to the prompts.
+ points (torch.Tensor | None): Tensor of shape (B, N, 2) representing the input points for N objects.
+ labels (torch.Tensor | None): Tensor of shape (B, N) representing the labels for the input points.
+ masks (torch.Tensor | None): Optional tensor of shape (N, H, W) representing the input masks for N objects.
+ """
+ consolidated_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ "pred_masks": torch.full(
+ size=(self._max_obj_num, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+ fill_value=-1024.0,
+ dtype=self.torch_dtype,
+ device=self.device,
+ ),
+ "obj_ptr": torch.full(
+ size=(self._max_obj_num, self.model.hidden_dim),
+ fill_value=-1024.0,
+ dtype=self.torch_dtype,
+ device=self.device,
+ ),
+ "object_score_logits": torch.full(
+ size=(self._max_obj_num, 1),
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+ fill_value=-32, # 10.0,
+ dtype=self.torch_dtype,
+ device=self.device,
+ ),
+ }
+
+ for i, obj_id in enumerate(obj_ids):
+ assert obj_id < self._max_obj_num
+ obj_idx = self._obj_id_to_idx(int(obj_id))
+ self.obj_idx_set.add(obj_idx)
+ point, label = points[[i]], labels[[i]]
+ mask = masks[[i]][None] if masks is not None else None
+ # Currently, only bbox prompt or mask prompt is supported, so we assert that bbox is not None.
+ assert point is not None or mask is not None, "Either bbox, points or mask is required"
+ out = self.track_step(obj_idx, point, label, mask)
+ if out is not None:
+ obj_mask = out["pred_masks"]
+ assert obj_mask.shape[-2:] == consolidated_out["pred_masks"].shape[-2:], (
+ f"Expected mask shape {consolidated_out['pred_masks'].shape[-2:]} but got {obj_mask.shape[-2:]} for object {obj_idx}."
+ )
+ consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = obj_mask
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+
+ if "object_score_logits" in out.keys():
+ consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out["object_score_logits"]
+
+ high_res_masks = F.interpolate(
+ consolidated_out["pred_masks"].to(self.device, non_blocking=self.device.type == "cuda"),
+ size=self.imgsz,
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ if self.model.non_overlap_masks_for_mem_enc:
+ high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
+ maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
+ current_vision_feats=self.vision_feats,
+ feat_sizes=self.feat_sizes,
+ pred_masks_high_res=high_res_masks,
+ object_score_logits=consolidated_out["object_score_logits"],
+ is_mask_from_pts=True,
+ )
+ consolidated_out["maskmem_features"] = maskmem_features
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+ self.memory_bank.append(consolidated_out)
+
+ def _prepare_memory_conditioned_features(self, obj_idx: int | None) -> torch.Tensor:
+ """Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it supposes to
+ prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features
+ for all objects in the image. If there is no memory, it will directly add a no-memory embedding to the
+ current vision features. If there is memory, it will use the memory features from previous frames to
+ condition the current vision features using a transformer attention mechanism.
+
+ Args:
+ obj_idx (int | None): The index of the object for which to prepare the features.
+
+ Returns:
+ pix_feat_with_mem (torch.Tensor): The memory-conditioned pixel features.
+ """
+ if len(self.memory_bank) == 0 or isinstance(obj_idx, int):
+ # for initial conditioning frames with, encode them without using any previous memory
+ # directly add no-mem embedding (instead of using the transformer encoder)
+ pix_feat_with_mem = self.vision_feats[-1] + self.model.no_mem_embed
+ else:
+ # for inference frames, use the memory features from previous frames
+ memory, memory_pos_embed = self.get_maskmem_enc()
+ pix_feat_with_mem = self.model.memory_attention(
+ curr=self.vision_feats[-1:],
+ curr_pos=self.vision_pos_embeds[-1:],
+ memory=memory,
+ memory_pos=memory_pos_embed,
+ num_obj_ptr_tokens=0, # num_obj_ptr_tokens
+ )
+ # reshape the output (HW)BC => BCHW
+ return pix_feat_with_mem.permute(1, 2, 0).view(
+ self._max_obj_num,
+ self.model.memory_attention.d_model,
+ *self.feat_sizes[-1],
+ )
+
+ def get_maskmem_enc(self) -> tuple[torch.Tensor, torch.Tensor]:
+ """Get memory and positional encoding from memory, which is used to condition the current image features."""
+ to_cat_memory, to_cat_memory_pos_embed = [], []
+ for consolidated_out in self.memory_bank:
+ to_cat_memory.append(consolidated_out["maskmem_features"].flatten(2).permute(2, 0, 1)) # (H*W, B, C)
+ maskmem_enc = consolidated_out["maskmem_pos_enc"][-1].flatten(2).permute(2, 0, 1)
+ maskmem_enc = maskmem_enc + self.model.maskmem_tpos_enc[self.model.num_maskmem - 1]
+ to_cat_memory_pos_embed.append(maskmem_enc)
+
+ memory = torch.cat(to_cat_memory, dim=0)
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+ return memory, memory_pos_embed
+
+ def _obj_id_to_idx(self, obj_id: int) -> int | None:
+ """Map client-side object id to model-side object index.
+
+ Args:
+ obj_id (int): The client-side object ID.
+
+ Returns:
+ (int): The model-side object index, or None if not found.
+ """
+ return self.obj_id_to_idx.get(obj_id, None)
+
+ def track_step(
+ self,
+ obj_idx: int | None = None,
+ point: torch.Tensor | None = None,
+ label: torch.Tensor | None = None,
+ mask: torch.Tensor | None = None,
+ ) -> dict[str, Any]:
+ """Tracking step for the current image state to predict masks.
+
+ This method processes the image features and runs the SAM heads to predict masks. If obj_idx is provided, it
+ processes the features for a specific prompted object in the image. If obj_idx is None, it processes the
+ features for all objects in the image. The method supports both mask-based output without SAM and full SAM
+ processing with memory-conditioned features.
+
+ Args:
+ obj_idx (int | None): The index of the object for which to predict masks. If None, it processes all objects.
+ point (torch.Tensor | None): The coordinates of the points of interest with shape (N, 2).
+ label (torch.Tensor | None): The labels corresponding to the points where 1 means positive clicks, 0 means
+ negative clicks.
+ mask (torch.Tensor | None): The mask input for the object with shape (H, W).
+
+ Returns:
+ current_out (dict[str, Any]): A dictionary containing the current output with mask predictions and object
+ pointers. Keys include 'point_inputs', 'mask_inputs', 'pred_masks', 'pred_masks_high_res',
+ 'obj_ptr', 'object_score_logits'.
+ """
+ if mask is not None and self.model.use_mask_input_as_output_without_sam:
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = self.vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.model.memory_attention.d_model, *self.feat_sizes[-1])
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._use_mask_as_output(mask)
+ else:
+ # fused the visual feature with previous memory features in the memory bank
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(obj_idx)
+ # calculate the first feature if adding obj_idx exists(means adding prompts)
+ pix_feat_with_mem = pix_feat_with_mem[:1] if obj_idx is not None else pix_feat_with_mem
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._forward_sam_heads(
+ backbone_features=pix_feat_with_mem,
+ point_inputs={"point_coords": point, "point_labels": label} if obj_idx is not None else None,
+ mask_inputs=mask,
+ multimask_output=False,
+ high_res_features=[feat[: pix_feat_with_mem.shape[0]] for feat in self.high_res_features],
+ )
+ return {
+ "pred_masks": low_res_masks,
+ "pred_masks_high_res": high_res_masks,
+ "obj_ptr": obj_ptr,
+ "object_score_logits": object_score_logits,
+ }