dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/memory_attention.py

@@ -1,17 +1,17 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 import copy
-from typing import Optional

 import torch
-from torch import Tensor, nn
+from torch import nn

 from .blocks import RoPEAttention


 class MemoryAttentionLayer(nn.Module):
-    """
-    Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
+    """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.

     This class combines self-attention, cross-attention, and feedforward components to process input tensors and
     generate memory-based attention outputs.
@@ -60,8 +60,7 @@ class MemoryAttentionLayer(nn.Module):
         pos_enc_at_cross_attn_keys: bool = True,
         pos_enc_at_cross_attn_queries: bool = False,
     ):
-        """
-        Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
+        """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.

         Args:
             d_model (int): Dimensionality of the model.
@@ -103,7 +102,7 @@ class MemoryAttentionLayer(nn.Module):
         self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
         self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

-    def _forward_sa(self, tgt: Tensor, query_pos: Optional[Tensor]) -> Tensor:
+    def _forward_sa(self, tgt: torch.Tensor, query_pos: torch.Tensor | None) -> torch.Tensor:
         """Perform self-attention on input tensor using positional encoding and RoPE attention mechanism."""
         tgt2 = self.norm1(tgt)
         q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
@@ -113,12 +112,12 @@ class MemoryAttentionLayer(nn.Module):

     def _forward_ca(
         self,
-        tgt: Tensor,
-        memory: Tensor,
-        query_pos: Optional[Tensor],
-        pos: Optional[Tensor],
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        query_pos: torch.Tensor | None,
+        pos: torch.Tensor | None,
         num_k_exclude_rope: int = 0,
-    ) -> Tensor:
+    ) -> torch.Tensor:
         """Perform cross-attention between target and memory tensors using RoPEAttention mechanism."""
         kwds = {}
         if num_k_exclude_rope > 0:
@@ -138,13 +137,24 @@ class MemoryAttentionLayer(nn.Module):

     def forward(
         self,
-        tgt: Tensor,
-        memory: Tensor,
-        pos: Optional[Tensor] = None,
-        query_pos: Optional[Tensor] = None,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        pos: torch.Tensor | None = None,
+        query_pos: torch.Tensor | None = None,
         num_k_exclude_rope: int = 0,
     ) -> torch.Tensor:
-        """Process input tensors through self-attention, cross-attention, and feedforward network layers."""
+        """Process input tensors through self-attention, cross-attention, and feedforward network layers.
+
+        Args:
+            tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).
+            memory (torch.Tensor): Memory tensor for cross-attention with shape (N, S, D).
+            pos (Optional[torch.Tensor]): Positional encoding for memory tensor.
+            query_pos (Optional[torch.Tensor]): Positional encoding for target tensor.
+            num_k_exclude_rope (int): Number of keys to exclude from rotary position embedding.
+
+        Returns:
+            (torch.Tensor): Processed tensor after attention and feedforward layers with shape (N, L, D).
+        """
         tgt = self._forward_sa(tgt, query_pos)
         tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
         # MLP
@@ -155,11 +165,10 @@ class MemoryAttentionLayer(nn.Module):


 class MemoryAttention(nn.Module):
-    """
-    Memory attention module for processing sequential data with self and cross-attention mechanisms.
+    """Memory attention module for processing sequential data with self and cross-attention mechanisms.

-    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
-    for processing sequential data, particularly useful in transformer-like architectures.
+    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
+    processing sequential data, particularly useful in transformer-like architectures.

     Attributes:
         d_model (int): The dimension of the model's hidden state.
@@ -193,11 +202,10 @@ class MemoryAttention(nn.Module):
         num_layers: int,
         batch_first: bool = True,  # Do layers expect batch first input?
     ):
-        """
-        Initialize MemoryAttention with specified layers and normalization for sequential data processing.
+        """Initialize MemoryAttention with specified layers and normalization for sequential data processing.

-        This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
-        for processing sequential data, particularly useful in transformer-like architectures.
+        This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
+        processing sequential data, particularly useful in transformer-like architectures.

         Args:
             d_model (int): The dimension of the model's hidden state.
@@ -230,18 +238,17 @@ class MemoryAttention(nn.Module):
         self,
         curr: torch.Tensor,  # self-attention inputs
         memory: torch.Tensor,  # cross-attention inputs
-        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
-        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
+        curr_pos: torch.Tensor | None = None,  # pos_enc for self-attention inputs
+        memory_pos: torch.Tensor | None = None,  # pos_enc for cross-attention inputs
         num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
     ) -> torch.Tensor:
-        """
-        Process inputs through attention layers, applying self and cross-attention with positional encoding.
+        """Process inputs through attention layers, applying self and cross-attention with positional encoding.

         Args:
             curr (torch.Tensor): Self-attention input tensor, representing the current state.
             memory (torch.Tensor): Cross-attention input tensor, representing memory information.
-            curr_pos (Optional[Tensor]): Positional encoding for self-attention inputs.
-            memory_pos (Optional[Tensor]): Positional encoding for cross-attention inputs.
+            curr_pos (Optional[torch.Tensor]): Positional encoding for self-attention inputs.
+            memory_pos (Optional[torch.Tensor]): Positional encoding for cross-attention inputs.
             num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.

         Returns:
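The hunks above consistently add `from __future__ import annotations` and replace `Optional[Tensor]` hints with `torch.Tensor | None`. A minimal standalone sketch of that annotation style follows; it is illustrative only and not code taken from the package.

```python
# Sketch of the PEP 604 style the diff adopts: with postponed evaluation of
# annotations, `torch.Tensor | None` works even on Python versions where the
# union syntax is not yet a runtime type expression.
from __future__ import annotations

import torch
from torch import nn


class TinyLayer(nn.Module):
    """Toy layer using the new-style hints; the names here are illustrative."""

    def forward(self, x: torch.Tensor, pos: torch.Tensor | None = None) -> torch.Tensor:
        # Add the positional encoding only when one is provided.
        return x if pos is None else x + pos


layer = TinyLayer()
out_with_pos = layer(torch.zeros(2, 4), torch.ones(2, 4))  # pos supplied
out_without_pos = layer(torch.zeros(2, 4))                 # pos omitted
```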
ultralytics/models/sam/modules/sam.py

@@ -3,10 +3,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.

-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import List
+from __future__ import annotations

 import torch
 import torch.nn.functional as F
@@ -26,20 +23,21 @@ NO_OBJ_SCORE = -1024.0


 class SAMModel(nn.Module):
-    """
-    Segment Anything Model (SAM) for object segmentation tasks.
+    """Segment Anything Model (SAM) for object segmentation tasks.

-    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
-    and input prompts.
+    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input
+    prompts.

     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
         image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
         prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
         mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
+        pixel_mean (torch.Tensor): Mean values for normalizing pixels in the input image.
+        pixel_std (torch.Tensor): Standard deviation values for normalizing pixels in the input image.

     Methods:
-
+        set_imgsz: Set image size to make model compatible with different image sizes.

     Examples:
         >>> image_encoder = ImageEncoderViT(...)
@@ -59,18 +57,17 @@ class SAMModel(nn.Module):
         image_encoder: ImageEncoderViT,
         prompt_encoder: PromptEncoder,
         mask_decoder: MaskDecoder,
-        pixel_mean: List[float] = (123.675, 116.28, 103.53),
-        pixel_std: List[float] = (58.395, 57.12, 57.375),
+        pixel_mean: list[float] = (123.675, 116.28, 103.53),
+        pixel_std: list[float] = (58.395, 57.12, 57.375),
     ) -> None:
-        """
-        Initialize the SAMModel class to predict object masks from an image and input prompts.
+        """Initialize the SAMModel class to predict object masks from an image and input prompts.

         Args:
             image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
             prompt_encoder (PromptEncoder): Encodes various types of input prompts.
             mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
-            pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
-            pixel_std (List[float]): Standard deviation values for normalizing pixels in the input image.
+            pixel_mean (list[float]): Mean values for normalizing pixels in the input image.
+            pixel_std (list[float]): Standard deviation values for normalizing pixels in the input image.

         Examples:
             >>> image_encoder = ImageEncoderViT(...)
@@ -90,12 +87,7 @@ class SAMModel(nn.Module):
         self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

     def set_imgsz(self, imgsz):
-        """
-        Set image size to make model compatible with different image sizes.
-
-        Args:
-            imgsz (Tuple[int, int]): The size of the input image.
-        """
+        """Set image size to make model compatible with different image sizes."""
         if hasattr(self.image_encoder, "set_imgsz"):
             self.image_encoder.set_imgsz(imgsz)
         self.prompt_encoder.input_image_size = imgsz
@@ -104,11 +96,10 @@ class SAMModel(nn.Module):


 class SAM2Model(torch.nn.Module):
-    """
-    SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
+    """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.

-    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
-    for temporal consistency and efficient tracking of objects across frames.
+    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal
+    consistency and efficient tracking of objects across frames.

     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
@@ -124,10 +115,48 @@ class SAM2Model(torch.nn.Module):
         sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
         obj_ptr_proj (nn.Module): Projection layer for object pointers.
         obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
+        hidden_dim (int): Hidden dimension of the model.
+        mem_dim (int): Memory dimension for encoding features.
+        use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
+        use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
+        max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention.
+        add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers.
+        proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
+            encoding in object pointers.
+        use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in temporal positional encoding.
+        only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
+            evaluation.
+        pred_obj_scores (bool): Whether to predict if there is an object in the frame.
+        pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
+        fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
+        soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
+        use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
+        no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.
+        max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
+        directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+            frame.
+        multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+            frames.
+        multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
+        multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
+        multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
+        use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
+        iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
+        memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
+        non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+            encoder during evaluation.
+        sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
+        sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
+        binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+            clicks during evaluation.
+        use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt
+            encoder and mask decoder on frames with mask input.

     Methods:
-        forward_image:
-        track_step:
+        forward_image: Process image batch through encoder to extract multi-level features.
+        track_step: Perform a single tracking step, updating object masks and memory features.
+        set_binarize: Set binarize for VideoPredictor.
+        set_imgsz: Set image size to make model compatible with different image sizes.

     Examples:
         >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
@@ -176,56 +205,53 @@ class SAM2Model(torch.nn.Module):
         sam_mask_decoder_extra_args=None,
         compile_image_encoder: bool = False,
     ):
-        """
-        Initialize the SAM2Model for video object segmentation with memory-based tracking.
+        """Initialize the SAM2Model for video object segmentation with memory-based tracking.

         Args:
             image_encoder (nn.Module): Visual encoder for extracting image features.
             memory_attention (nn.Module): Module for attending to memory features.
             memory_encoder (nn.Module): Encoder for generating memory representations.
-            num_maskmem (int): Number of accessible memory frames.
+            num_maskmem (int): Number of accessible memory frames.
             image_size (int): Size of input images.
             backbone_stride (int): Stride of the image backbone output.
             sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
             sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
-            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
-                with clicks during evaluation.
+            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+                clicks during evaluation.
             use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
                 prompt encoder and mask decoder on frames with mask input.
             max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
-
-
-                first frame.
+            directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+                frame.
             use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
-            multimask_output_in_sam (bool): Whether to output multiple
-
+            multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+                frames.
             multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
             multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
             multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
             use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
             iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
            memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
-            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
-                memory encoder during evaluation.
+            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+                encoder during evaluation.
             use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
             max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
                 cross-attention.
-            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
-                the encoder.
+            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the
+                encoder.
             proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
                 encoding in object pointers.
-            use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance
-                in the
-
-
-                during evaluation.
+            use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
+                in the object pointers.
+            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
+                evaluation.
             pred_obj_scores (bool): Whether to predict if there is an object in the frame.
             pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
             fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
             soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
             use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
             no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
-            sam_mask_decoder_extra_args (
+            sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder.
             compile_image_encoder (bool): Whether to compile the image encoder for faster inference.

         Examples:
@@ -398,36 +424,32 @@ class SAM2Model(torch.nn.Module):
         high_res_features=None,
         multimask_output=False,
     ):
-        """
-        Forward pass through SAM prompt encoders and mask heads.
+        """Forward pass through SAM prompt encoders and mask heads.

         This method processes image features and optional point/mask inputs to generate object masks and scores.

         Args:
             backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
-            point_inputs (
-
-
-
-
-            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
-
-            high_res_features (
-
-
-
-                output only 1 mask and its IoU estimate.
+            point_inputs (dict[str, torch.Tensor] | None): Dictionary containing point prompts.
+                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in
+                (x, y) format for P input points.
+                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative
+                clicks, and -1 means padding.
+            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial
+                size as the image.
+            high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B,
+                C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder.
+            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1
+                mask and its IoU estimate.

         Returns:
-            (
-
-
-
-
-
-            object_score_logits: Tensor of shape (B) with object score logits.
-            Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
+            low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
+            high_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
+            ious (torch.Tensor): Tensor of shape (B, M) with estimated IoU for each output mask.
+            low_res_masks (torch.Tensor): Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
+            high_res_masks (torch.Tensor): Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
+            obj_ptr (torch.Tensor): Tensor of shape (B, C) with object pointer vector for the output mask.
+            object_score_logits (torch.Tensor): Tensor of shape (B) with object score logits.

         Examples:
             >>> backbone_features = torch.rand(1, 256, 32, 32)
@@ -444,7 +466,7 @@ class SAM2Model(torch.nn.Module):
             ...     object_score_logits,
             ... ) = results
         """
-        B = backbone_features.size(0)
+        B = backbone_features.shape[0]
         device = backbone_features.device
         assert backbone_features.size(1) == self.sam_prompt_embed_dim
         assert backbone_features.size(2) == self.sam_image_embedding_size
@@ -454,10 +476,10 @@ class SAM2Model(torch.nn.Module):
         if point_inputs is not None:
             sam_point_coords = point_inputs["point_coords"]
             sam_point_labels = point_inputs["point_labels"]
-            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+            assert sam_point_coords.shape[0] == B and sam_point_labels.shape[0] == B
         else:
             # If no points are provide, pad with an empty point (with label -1)
-            sam_point_coords = torch.zeros(B, 1, 2, device=device)
+            sam_point_coords = torch.zeros(B, 1, 2, device=device, dtype=backbone_features.dtype)
             sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

         # b) Handle mask prompts
@@ -502,7 +524,6 @@ class SAM2Model(torch.nn.Module):

         # convert masks from possibly bfloat16 (or float16) to float32
         # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
-        low_res_multimasks = low_res_multimasks.float()
         high_res_multimasks = F.interpolate(
             low_res_multimasks,
             size=(self.image_size, self.image_size),
@@ -529,12 +550,11 @@ class SAM2Model(torch.nn.Module):
             if self.soft_no_obj_ptr:
                 lambda_is_obj_appearing = object_score_logits.sigmoid()
             else:
-                lambda_is_obj_appearing = is_obj_appearing.float()
+                lambda_is_obj_appearing = is_obj_appearing.to(obj_ptr.dtype)

             if self.fixed_no_obj_ptr:
                 obj_ptr = lambda_is_obj_appearing * obj_ptr
             obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
-
         return (
             low_res_multimasks,
             high_res_multimasks,
@@ -545,7 +565,7 @@ class SAM2Model(torch.nn.Module):
             object_score_logits,
         )

-    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+    def _use_mask_as_output(self, mask_inputs, backbone_features=None, high_res_features=None):
         """Process mask inputs directly as output, bypassing SAM encoder/decoder."""
         # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
         out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
@@ -559,10 +579,10 @@ class SAM2Model(torch.nn.Module):
             antialias=True,  # use antialias for downsampling
         )
         # a dummy IoU prediction of all 1's under mask input
-        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
-        if not self.use_obj_ptrs_in_encoder:
+        ious = mask_inputs.new_ones(mask_inputs.shape[0], 1).float()
+        if not self.use_obj_ptrs_in_encoder or backbone_features is None or high_res_features is None:
             # all zeros as a dummy object pointer (of shape [B, C])
-            obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
+            obj_ptr = torch.zeros(mask_inputs.shape[0], self.hidden_dim, device=mask_inputs.device)
         else:
             # produce an object pointer using the SAM decoder from the mask input
             _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
@@ -686,7 +706,7 @@ class SAM2Model(torch.nn.Module):
                     continue  # skip padding frames
                 # "maskmem_features" might have been offloaded to CPU in demo use cases,
                 # so we load it back to inference device (it's a no-op if it's already on device).
-                feats = prev["maskmem_features"].to(device=device, non_blocking=True)
+                feats = prev["maskmem_features"].to(device=device, non_blocking=device.type == "cuda")
                 to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                 # Spatial positional encoding (it might have been offloaded to CPU in eval)
                 maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
@@ -738,7 +758,7 @@ class SAM2Model(torch.nn.Module):
                 if self.add_tpos_enc_to_obj_ptrs:
                     t_diff_max = max_obj_ptrs_in_encoder - 1
                     tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
-                    obj_pos = torch.tensor(pos_list, device=device)
+                    obj_pos = torch.tensor(pos_list, device=device, dtype=current_vision_feats[-1].dtype)
                     obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                     obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                     obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
@@ -803,7 +823,7 @@ class SAM2Model(torch.nn.Module):
         # scale the raw mask logits with a temperature before applying sigmoid
         binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
         if binarize and not self.training:
-            mask_for_mem = (pred_masks_high_res > 0).float()
+            mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype)
         else:
             # apply sigmoid on the raw mask logits to turn them into range (0, 1)
             mask_for_mem = torch.sigmoid(pred_masks_high_res)
@@ -840,7 +860,6 @@ class SAM2Model(torch.nn.Module):
         prev_sam_mask_logits,
     ):
         """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
-        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
         # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
         if len(current_vision_feats) > 1:
             high_res_features = [
@@ -854,7 +873,7 @@ class SAM2Model(torch.nn.Module):
             # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
             pix_feat = current_vision_feats[-1].permute(1, 2, 0)
             pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
-            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+            sam_outputs = self._use_mask_as_output(mask_inputs, pix_feat, high_res_features)
         else:
             # fused the visual feature with previous memory features in the memory bank
             pix_feat = self._prepare_memory_conditioned_features(
@@ -882,7 +901,7 @@ class SAM2Model(torch.nn.Module):
             high_res_features=high_res_features,
             multimask_output=multimask_output,
         )
-        return current_out, sam_outputs, high_res_features, pix_feat
+        return sam_outputs, high_res_features, pix_feat

     def _encode_memory_in_output(
         self,
@@ -896,11 +915,10 @@ class SAM2Model(torch.nn.Module):
     ):
         """Run memory encoder on predicted mask to encode it into a new memory feature for future frames."""
         if run_mem_encoder and self.num_maskmem > 0:
-            high_res_masks_for_mem_enc = high_res_masks
             maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                 current_vision_feats=current_vision_feats,
                 feat_sizes=feat_sizes,
-                pred_masks_high_res=high_res_masks_for_mem_enc,
+                pred_masks_high_res=high_res_masks,
                 object_score_logits=object_score_logits,
                 is_mask_from_pts=(point_inputs is not None),
             )
@@ -932,7 +950,7 @@ class SAM2Model(torch.nn.Module):
         prev_sam_mask_logits=None,
     ):
         """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
-        current_out, sam_outputs, _, _ = self._track_step(
+        sam_outputs, _, _ = self._track_step(
             frame_idx,
             is_init_cond_frame,
             current_vision_feats,
@@ -947,9 +965,11 @@ class SAM2Model(torch.nn.Module):
         )
         _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs

-        current_out["pred_masks"] = low_res_masks
-        current_out["pred_masks_high_res"] = high_res_masks
-        current_out["obj_ptr"] = obj_ptr
+        current_out = {
+            "pred_masks": low_res_masks,
+            "pred_masks_high_res": high_res_masks,
+            "obj_ptr": obj_ptr,
+        }
         if not self.training:
             # Only add this in inference (to avoid unused param in activation checkpointing;
             # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
@@ -980,7 +1000,7 @@ class SAM2Model(torch.nn.Module):
     @staticmethod
     def _apply_non_overlapping_constraints(pred_masks):
         """Apply non-overlapping constraints to masks, keeping the highest scoring object per location."""
-        batch_size = pred_masks.size(0)
+        batch_size = pred_masks.shape[0]
         if batch_size == 1:
             return pred_masks

@@ -1004,3 +1024,4 @@ class SAM2Model(torch.nn.Module):
         self.image_size = imgsz[0]
         self.sam_prompt_encoder.input_image_size = imgsz
         self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
+        self.sam_image_embedding_size = self.image_size // self.backbone_stride  # update image embedding size