dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff compares the contents of two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/blocks.py

@@ -17,8 +17,7 @@ from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis,
 
 
 class DropPath(nn.Module):
-    """
-    Implements stochastic depth regularization for neural networks during training.
+    """Implements stochastic depth regularization for neural networks during training.
 
     Attributes:
         drop_prob (float): Probability of dropping a path during training.
@@ -52,16 +51,14 @@ class DropPath(nn.Module):
 
 
 class MaskDownSampler(nn.Module):
-    """
-    A mask downsampling and embedding module for efficient processing of input masks.
+    """A mask downsampling and embedding module for efficient processing of input masks.
 
-    This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks
-    while expanding their channel dimensions using convolutional layers, layer normalization, and activation
-    functions.
+    This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks while
+    expanding their channel dimensions using convolutional layers, layer normalization, and activation functions.
 
     Attributes:
-        encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
-            activation functions for downsampling and embedding masks.
+        encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and activation
+            functions for downsampling and embedding masks.
 
     Methods:
         forward: Downsamples and encodes input mask to embed_dim channels.
@@ -82,6 +79,7 @@ class MaskDownSampler(nn.Module):
         padding: int = 0,
         total_stride: int = 16,
         activation: type[nn.Module] = nn.GELU,
+        interpol_size: tuple[int, int] | None = None,
     ):
         """Initialize a mask downsampler module for progressive downsampling and channel expansion."""
         super().__init__()
@@ -105,18 +103,32 @@ class MaskDownSampler(nn.Module):
             mask_in_chans = mask_out_chans
 
         self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+        self.interpol_size = interpol_size
+        if self.interpol_size is not None:
+            assert isinstance(self.interpol_size, (list, tuple)), (
+                f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
+            )
+            self.interpol_size = list(interpol_size)
+            assert len(self.interpol_size) == 2
 
     def forward(self, x: Tensor) -> Tensor:
         """Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+        if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
+            x = F.interpolate(
+                x.float(),
+                size=self.interpol_size,
+                align_corners=False,
+                mode="bilinear",
+                antialias=True,
+            ).to(x.dtype)
         return self.encoder(x)
 
 
 class CXBlock(nn.Module):
-    """
-    ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+    """ConvNeXt Block for efficient feature extraction in convolutional neural networks.
 
-    This block implements a modified version of the ConvNeXt architecture, offering improved performance and
-    flexibility in feature extraction.
+    This block implements a modified version of the ConvNeXt architecture, offering improved performance and flexibility
+    in feature extraction.
 
     Attributes:
         dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
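The hunk above adds an optional `interpol_size` argument to `MaskDownSampler`: when set, `forward` first resizes the incoming mask to that fixed spatial size (bilinear, antialiased) before running the strided convolutional encoder. A minimal sketch of that behavior, with illustrative sizes that are not taken from any particular SAM configuration:

```python
# Sketch of the new interpol_size behavior; the target size and mask shape below are illustrative.
import torch
import torch.nn.functional as F

interpol_size = [1024, 1024]       # fixed target resolution, validated in __init__ as a 2-element list/tuple
mask = torch.rand(1, 1, 512, 512)  # hypothetical input mask at a different resolution

# forward() resizes only when the spatial size differs from interpol_size ...
if interpol_size != list(mask.shape[-2:]):
    mask = F.interpolate(
        mask.float(), size=interpol_size, align_corners=False, mode="bilinear", antialias=True
    ).to(mask.dtype)

# ... so the downstream strided conv encoder always sees a fixed, known input size.
print(mask.shape)  # torch.Size([1, 1, 1024, 1024])
```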
@@ -148,8 +160,7 @@ class CXBlock(nn.Module):
         layer_scale_init_value: float = 1e-6,
         use_dwconv: bool = True,
     ):
-        """
-        Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+        """Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
 
         This block implements a modified version of the ConvNeXt architecture, offering improved performance and
         flexibility in feature extraction.
@@ -161,13 +172,6 @@ class CXBlock(nn.Module):
             drop_path (float): Stochastic depth rate.
             layer_scale_init_value (float): Initial value for Layer Scale.
             use_dwconv (bool): Whether to use depthwise convolution.
-
-        Examples:
-            >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
-            >>> x = torch.randn(1, 64, 32, 32)
-            >>> output = block(x)
-            >>> print(output.shape)
-            torch.Size([1, 64, 32, 32])
         """
         super().__init__()
         self.dwconv = nn.Conv2d(
@@ -206,8 +210,7 @@ class CXBlock(nn.Module):
 
 
 class Fuser(nn.Module):
-    """
-    A module for fusing features through multiple layers of a neural network.
+    """A module for fusing features through multiple layers of a neural network.
 
     This class applies a series of identical layers to an input tensor, optionally projecting the input first.
 
@@ -228,8 +231,7 @@ class Fuser(nn.Module):
     """
 
     def __init__(self, layer: nn.Module, num_layers: int, dim: int | None = None, input_projection: bool = False):
-        """
-        Initialize the Fuser module for feature fusion through multiple layers.
+        """Initialize the Fuser module for feature fusion through multiple layers.
 
         This module creates a sequence of identical layers and optionally applies an input projection.
 
@@ -238,12 +240,6 @@ class Fuser(nn.Module):
             num_layers (int): The number of times to replicate the layer.
             dim (int | None): The dimension for input projection, if used.
             input_projection (bool): Whether to use input projection.
-
-        Examples:
-            >>> layer = nn.Linear(64, 64)
-            >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
-            >>> input_tensor = torch.randn(1, 64)
-            >>> output = fuser(input_tensor)
         """
         super().__init__()
         self.proj = nn.Identity()
@@ -262,12 +258,11 @@ class Fuser(nn.Module):
 
 
 class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
-    """
-    A two-way attention block for performing self-attention and cross-attention in both directions.
+    """A two-way attention block for performing self-attention and cross-attention in both directions.
 
-    This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on
-    sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
-    cross-attention from dense to sparse inputs.
+    This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse inputs,
+    cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from dense to sparse
+    inputs.
 
     Attributes:
         self_attn (Attention): Self-attention layer for queries.
@@ -299,12 +294,11 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
         attention_downsample_rate: int = 2,
         skip_first_layer_pe: bool = False,
     ) -> None:
-        """
-        Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
+        """Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
 
         This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
-        inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention
-        from dense to sparse inputs.
+        inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from
+        dense to sparse inputs.
 
         Args:
             embedding_dim (int): The channel dimension of the embeddings.
@@ -313,24 +307,17 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
             activation (Type[nn.Module]): The activation function of the MLP block.
             attention_downsample_rate (int): The downsample rate for attention computations.
             skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
-
-        Examples:
-            >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
-            >>> sparse_inputs = torch.randn(1, 100, 256)
-            >>> dense_inputs = torch.randn(1, 256, 32, 32)
-            >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
         """
         super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
         self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
 
 
 class SAM2TwoWayTransformer(TwoWayTransformer):
-    """
-    A Two-Way Transformer module for simultaneous attention to image and query points.
+    """A Two-Way Transformer module for simultaneous attention to image and query points.
 
-    This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an
-    input image using queries with supplied positional embeddings. It is particularly useful for tasks like
-    object detection, image segmentation, and point cloud processing.
+    This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an input
+    image using queries with supplied positional embeddings. It is particularly useful for tasks like object detection,
+    image segmentation, and point cloud processing.
 
     Attributes:
         depth (int): Number of layers in the transformer.
@@ -362,11 +349,10 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
         activation: type[nn.Module] = nn.ReLU,
         attention_downsample_rate: int = 2,
     ) -> None:
-        """
-        Initialize a SAM2TwoWayTransformer instance.
+        """Initialize a SAM2TwoWayTransformer instance.
 
-        This transformer decoder attends to an input image using queries with supplied positional embeddings.
-        It is designed for tasks like object detection, image segmentation, and point cloud processing.
+        This transformer decoder attends to an input image using queries with supplied positional embeddings. It is
+        designed for tasks like object detection, image segmentation, and point cloud processing.
 
         Args:
             depth (int): Number of layers in the transformer.
@@ -375,17 +361,6 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
             mlp_dim (int): Channel dimension internal to the MLP block.
             activation (Type[nn.Module]): Activation function to use in the MLP block.
             attention_downsample_rate (int): Downsampling rate for attention computations.
-
-        Examples:
-            >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
-            >>> transformer
-            SAM2TwoWayTransformer(
-              (layers): ModuleList(
-                (0-4): 5 x SAM2TwoWayAttentionBlock(...)
-              )
-              (final_attn_token_to_image): Attention(...)
-              (norm_final_attn): LayerNorm(...)
-            )
         """
         super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
         self.layers = nn.ModuleList()
@@ -403,11 +378,10 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
 
 
 class RoPEAttention(Attention):
-    """
-    Implements rotary position encoding for attention mechanisms in transformer architectures.
+    """Implements rotary position encoding for attention mechanisms in transformer architectures.
 
-    This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance
-    the positional awareness of the attention mechanism.
+    This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance the
+    positional awareness of the attention mechanism.
 
     Attributes:
         compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
@@ -471,13 +445,7 @@ class RoPEAttention(Attention):
         )
 
         # Attention
-
-        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
-        attn = attn / math.sqrt(c_per_head)
-        attn = torch.softmax(attn, dim=-1)
-
-        # Get output
-        out = attn @ v
+        out = F.scaled_dot_product_attention(q, k, v)
 
         out = self._recombine_heads(out)
         out = self.out_proj(out)
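The last hunk above replaces the hand-written attention in `RoPEAttention.forward` (explicit `softmax(q @ k.T / sqrt(d)) @ v`) with a single call to `torch.nn.functional.scaled_dot_product_attention`, which computes the same quantity but can dispatch to fused Flash or memory-efficient kernels. A quick equivalence check with made-up shapes (no masking or dropout, matching this code path):

```python
# Sketch: the removed manual attention and the new fused call produce the same result.
import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 64, 32)  # B x N_heads x N_tokens x C_per_head (illustrative shapes)
k = torch.randn(2, 8, 64, 32)
v = torch.randn(2, 8, 64, 32)

# Old path: explicit softmax(q @ k^T / sqrt(d)) @ v
attn = torch.softmax(q @ k.permute(0, 1, 3, 2) / math.sqrt(q.shape[-1]), dim=-1)
out_manual = attn @ v

# New path: single fused call with the same default 1/sqrt(d) scaling
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_manual, out_fused, atol=1e-5))  # True, up to numerical tolerance
```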
@@ -501,12 +469,11 @@ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.T
 
 
 class MultiScaleAttention(nn.Module):
-    """
-    Implements multiscale self-attention with optional query pooling for efficient feature extraction.
+    """Implements multiscale self-attention with optional query pooling for efficient feature extraction.
 
-    This class provides a flexible implementation of multiscale attention, allowing for optional
-    downsampling of query features through pooling. It's designed to enhance the model's ability to capture
-    multiscale information in visual tasks.
+    This class provides a flexible implementation of multiscale attention, allowing for optional downsampling of query
+    features through pooling. It's designed to enhance the model's ability to capture multiscale information in visual
+    tasks.
 
     Attributes:
         dim (int): Input dimension of the feature map.
@@ -581,11 +548,10 @@ class MultiScaleAttention(nn.Module):
 
 
 class MultiScaleBlock(nn.Module):
-    """
-    A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
+    """A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
 
-    This class implements a multiscale attention mechanism with optional window partitioning and downsampling,
-    designed for use in vision transformer architectures.
+    This class implements a multiscale attention mechanism with optional window partitioning and downsampling, designed
+    for use in vision transformer architectures.
 
     Attributes:
         dim (int): Input dimension of the block.
@@ -619,7 +585,7 @@ class MultiScaleBlock(nn.Module):
         mlp_ratio: float = 4.0,
         drop_path: float = 0.0,
         norm_layer: nn.Module | str = "LayerNorm",
-        q_stride: tuple[int, int] = None,
+        q_stride: tuple[int, int] | None = None,
         act_layer: type[nn.Module] = nn.GELU,
         window_size: int = 0,
     ):
@@ -696,11 +662,10 @@ class MultiScaleBlock(nn.Module):
 
 
 class PositionEmbeddingSine(nn.Module):
-    """
-    A module for generating sinusoidal positional embeddings for 2D inputs like images.
+    """A module for generating sinusoidal positional embeddings for 2D inputs like images.
 
-    This class implements sinusoidal position encoding for 2D spatial positions, which can be used in
-    transformer-based models for computer vision tasks.
+    This class implements sinusoidal position encoding for 2D spatial positions, which can be used in transformer-based
+    models for computer vision tasks.
 
     Attributes:
         num_pos_feats (int): Number of positional features (half of the embedding dimension).
@@ -811,8 +776,7 @@ class PositionEmbeddingSine(nn.Module):
 
 
 class PositionEmbeddingRandom(nn.Module):
-    """
-    Positional encoding using random spatial frequencies.
+    """Positional encoding using random spatial frequencies.
 
     This class generates positional embeddings for input coordinates using random spatial frequencies. It is
     particularly useful for transformer-based models that require position information.
@@ -878,12 +842,11 @@ class PositionEmbeddingRandom(nn.Module):
 
 
 class Block(nn.Module):
-    """
-    Transformer block with support for window attention and residual propagation.
+    """Transformer block with support for window attention and residual propagation.
 
-    This class implements a transformer block that can use either global or windowed self-attention,
-    followed by a feed-forward network. It supports relative positional embeddings and is designed for use in
-    vision transformer architectures.
+    This class implements a transformer block that can use either global or windowed self-attention, followed by a
+    feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
+    architectures.
 
     Attributes:
         norm1 (nn.Module): First normalization layer.
@@ -917,12 +880,11 @@ class Block(nn.Module):
         window_size: int = 0,
         input_size: tuple[int, int] | None = None,
     ) -> None:
-        """
-        Initialize a transformer block with optional window attention and relative positional embeddings.
+        """Initialize a transformer block with optional window attention and relative positional embeddings.
 
-        This constructor sets up a transformer block that can use either global or windowed self-attention,
-        followed by a feed-forward network. It supports relative positional embeddings and is designed for use
-        in vision transformer architectures.
+        This constructor sets up a transformer block that can use either global or windowed self-attention, followed by
+        a feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
+        architectures.
 
         Args:
            dim (int): Number of input channels.
@@ -935,13 +897,6 @@ class Block(nn.Module):
            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
            window_size (int): Size of attention window. If 0, uses global attention.
            input_size (tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
-
-        Examples:
-            >>> block = Block(dim=256, num_heads=8, window_size=7)
-            >>> x = torch.randn(1, 56, 56, 256)
-            >>> output = block(x)
-            >>> print(output.shape)
-            torch.Size([1, 56, 56, 256])
         """
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -978,12 +933,11 @@ class Block(nn.Module):
 
 
 class REAttention(nn.Module):
-    """
-    Relative Position Attention module for efficient self-attention in transformer architectures.
+    """Relative Position Attention module for efficient self-attention in transformer architectures.
 
-    This class implements a multi-head attention mechanism with relative positional embeddings, designed
-    for use in vision transformer models. It supports optional query pooling and window partitioning for
-    efficient processing of large inputs.
+    This class implements a multi-head attention mechanism with relative positional embeddings, designed for use in
+    vision transformer models. It supports optional query pooling and window partitioning for efficient processing of
+    large inputs.
 
     Attributes:
         num_heads (int): Number of attention heads.
@@ -1014,11 +968,10 @@ class REAttention(nn.Module):
         rel_pos_zero_init: bool = True,
         input_size: tuple[int, int] | None = None,
     ) -> None:
-        """
-        Initialize a Relative Position Attention module for transformer-based architectures.
+        """Initialize a Relative Position Attention module for transformer-based architectures.
 
-        This module implements multi-head attention with optional relative positional encodings, designed
-        specifically for vision tasks in transformer models.
+        This module implements multi-head attention with optional relative positional encodings, designed specifically
+        for vision tasks in transformer models.
 
         Args:
            dim (int): Number of input channels.
@@ -1028,13 +981,6 @@ class REAttention(nn.Module):
            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
            input_size (tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
                Required if use_rel_pos is True.
-
-        Examples:
-            >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
-            >>> x = torch.randn(1, 32, 32, 256)
-            >>> output = attention(x)
-            >>> print(output.shape)
-            torch.Size([1, 32, 32, 256])
         """
         super().__init__()
         self.num_heads = num_heads
@@ -1070,12 +1016,11 @@ class REAttention(nn.Module):
 
 
 class PatchEmbed(nn.Module):
-    """
-    Image to Patch Embedding module for vision transformer architectures.
+    """Image to Patch Embedding module for vision transformer architectures.
 
-    This module converts an input image into a sequence of patch embeddings using a convolutional layer.
-    It is commonly used as the first layer in vision transformer architectures to transform image data into a
-    suitable format for subsequent transformer blocks.
+    This module converts an input image into a sequence of patch embeddings using a convolutional layer. It is commonly
+    used as the first layer in vision transformer architectures to transform image data into a suitable format for
+    subsequent transformer blocks.
 
     Attributes:
         proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.
@@ -1098,12 +1043,12 @@ class PatchEmbed(nn.Module):
         padding: tuple[int, int] = (0, 0),
         in_chans: int = 3,
         embed_dim: int = 768,
+        bias: bool = True,
     ) -> None:
-        """
-        Initialize the PatchEmbed module for converting image patches to embeddings.
+        """Initialize the PatchEmbed module for converting image patches to embeddings.
 
-        This module is typically used as the first layer in vision transformer architectures to transform
-        image data into a suitable format for subsequent transformer blocks.
+        This module is typically used as the first layer in vision transformer architectures to transform image data
+        into a suitable format for subsequent transformer blocks.
 
         Args:
            kernel_size (tuple[int, int]): Size of the convolutional kernel for patch extraction.
@@ -1111,17 +1056,11 @@ class PatchEmbed(nn.Module):
            padding (tuple[int, int]): Padding applied to the input before convolution.
            in_chans (int): Number of input image channels.
            embed_dim (int): Dimensionality of the output patch embeddings.
-
-        Examples:
-            >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
-            >>> x = torch.randn(1, 3, 224, 224)
-            >>> output = patch_embed(x)
-            >>> print(output.shape)
-            torch.Size([1, 768, 14, 14])
+            bias (bool): Whether to include a bias term in the convolutional layer.
         """
         super().__init__()
 
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Compute patch embedding by applying convolution and transposing resulting tensor."""
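The final blocks.py hunks add a `bias` flag to `PatchEmbed` and forward it to the underlying `nn.Conv2d`; disabling the bias is a common choice when the projection is immediately followed by a normalization layer. A small sketch of the projection itself (the 16x16 kernel and 224x224 input are illustrative, not tied to a specific SAM variant):

```python
# Patch projection with the bias flag that PatchEmbed now exposes (illustrative configuration).
import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), bias=False)

x = torch.randn(1, 3, 224, 224)
print(proj(x).shape)  # torch.Size([1, 768, 14, 14]), before the module's final transpose
```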
ultralytics/models/sam/modules/decoders.py

@@ -9,8 +9,7 @@ from ultralytics.nn.modules import MLP, LayerNorm2d
 
 
 class MaskDecoder(nn.Module):
-    """
-    Decoder module for generating masks and their associated quality scores using a transformer architecture.
+    """Decoder module for generating masks and their associated quality scores using a transformer architecture.
 
     This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
     generate mask predictions along with their quality scores.
@@ -47,8 +46,7 @@ class MaskDecoder(nn.Module):
         iou_head_depth: int = 3,
         iou_head_hidden_dim: int = 256,
     ) -> None:
-        """
-        Initialize the MaskDecoder module for generating masks and their associated quality scores.
+        """Initialize the MaskDecoder module for generating masks and their associated quality scores.
 
         Args:
             transformer_dim (int): Channel dimension for the transformer module.
@@ -57,11 +55,6 @@ class MaskDecoder(nn.Module):
            activation (Type[nn.Module]): Type of activation to use when upscaling masks.
            iou_head_depth (int): Depth of the MLP used to predict mask quality.
            iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
-
-        Examples:
-            >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
-            >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
-            >>> print(decoder)
         """
         super().__init__()
         self.transformer_dim = transformer_dim
@@ -94,8 +87,7 @@ class MaskDecoder(nn.Module):
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Predict masks given image and prompt embeddings.
+        """Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder.
@@ -172,11 +164,10 @@ class MaskDecoder(nn.Module):
 
 
 class SAM2MaskDecoder(nn.Module):
-    """
-    Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
+    """Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
 
-    This class extends the functionality of the MaskDecoder, incorporating additional features such as
-    high-resolution feature processing, dynamic multimask output, and object score prediction.
+    This class extends the functionality of the MaskDecoder, incorporating additional features such as high-resolution
+    feature processing, dynamic multimask output, and object score prediction.
 
     Attributes:
         transformer_dim (int): Channel dimension of the transformer.
@@ -233,11 +224,10 @@ class SAM2MaskDecoder(nn.Module):
         pred_obj_scores_mlp: bool = False,
         use_multimask_token_for_obj_ptr: bool = False,
     ) -> None:
-        """
-        Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
+        """Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
 
-        This decoder extends the functionality of MaskDecoder, incorporating additional features such as
-        high-resolution feature processing, dynamic multimask output, and object score prediction.
+        This decoder extends the functionality of MaskDecoder, incorporating additional features such as high-resolution
+        feature processing, dynamic multimask output, and object score prediction.
 
         Args:
            transformer_dim (int): Channel dimension of the transformer.
@@ -254,11 +244,6 @@ class SAM2MaskDecoder(nn.Module):
            pred_obj_scores (bool): Whether to predict object scores.
            pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
            use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
-
-        Examples:
-            >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
-            >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
-            >>> print(decoder)
         """
         super().__init__()
         self.transformer_dim = transformer_dim
@@ -319,8 +304,7 @@ class SAM2MaskDecoder(nn.Module):
         repeat_image: bool,
         high_res_features: list[torch.Tensor] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Predict masks given image and prompt embeddings.
+        """Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
@@ -452,23 +436,21 @@ class SAM2MaskDecoder(nn.Module):
     def _get_stability_scores(self, mask_logits):
         """Compute mask stability scores based on IoU between upper and lower thresholds."""
         mask_logits = mask_logits.flatten(-2)
-        stability_delta = self.dynamic_multimask_stability_delta
-        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
-        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+        area_i = torch.sum(mask_logits > self.dynamic_multimask_stability_delta, dim=-1).float()
+        area_u = torch.sum(mask_logits > -self.dynamic_multimask_stability_delta, dim=-1).float()
         return torch.where(area_u > 0, area_i / area_u, 1.0)
 
     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
-        """
-        Dynamically select the most stable mask output based on stability scores and IoU predictions.
+        """Dynamically select the most stable mask output based on stability scores and IoU predictions.
 
-        This method is used when outputting a single mask. If the stability score from the current single-mask
-        output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
-        (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
-        for both clicking and tracking scenarios.
+        This method is used when outputting a single mask. If the stability score from the current single-mask output
+        (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs (based on output
+        tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask for both clicking and
+        tracking scenarios.
 
         Args:
-            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
-                batch size, N is number of masks (typically 4), and H, W are mask dimensions.
+            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is batch size, N
+                is number of masks (typically 4), and H, W are mask dimensions.
            all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
 
        Returns: