dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
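The largest single-file diff rendered on this page is ultralytics/models/sam/modules/decoders.py, reproduced below. Most of its hunks follow one mechanical pattern: the `typing` import is replaced by `from __future__ import annotations`, after which `Type[...]`, `Tuple[...]`, `List[...]`, and `Optional[...]` annotations become the builtin `type[...]`, `tuple[...]`, `list[...]`, and `... | None` forms. As a minimal sketch of why the `__future__` import makes this safe on older Python versions (the function below is illustrative and not part of the package):

    # Postponed evaluation of annotations (PEP 563): annotations are stored as strings,
    # so PEP 585/604 syntax such as list[torch.Tensor] and "X | None" parses even on
    # Python versions that predate those runtime typing features.
    from __future__ import annotations

    import torch
    from torch import nn


    def upscale_all(feats: list[torch.Tensor], act: type[nn.Module] = nn.GELU) -> torch.Tensor | None:
        """Illustrative only: concatenate features and apply the activation, or return None if empty."""
        return act()(torch.cat(feats, dim=0)) if feats else None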
--- ultralytics/models/sam/modules/decoders.py (8.3.137)
+++ ultralytics/models/sam/modules/decoders.py (8.3.224)
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-from typing import List, Optional, Tuple, Type
+from __future__ import annotations
 
 import torch
 from torch import nn
@@ -9,8 +9,7 @@ from ultralytics.nn.modules import MLP, LayerNorm2d
 
 
 class MaskDecoder(nn.Module):
-    """
-    Decoder module for generating masks and their associated quality scores using a transformer architecture.
+    """Decoder module for generating masks and their associated quality scores using a transformer architecture.
 
     This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
     generate mask predictions along with their quality scores.
@@ -27,7 +26,7 @@ class MaskDecoder(nn.Module):
         iou_prediction_head (nn.Module): MLP for predicting mask quality.
 
     Methods:
-        forward: Predicts masks given image and prompt embeddings.
+        forward: Predict masks given image and prompt embeddings.
         predict_masks: Internal method for mask prediction.
 
     Examples:
@@ -43,12 +42,11 @@ class MaskDecoder(nn.Module):
         transformer_dim: int,
         transformer: nn.Module,
         num_multimask_outputs: int = 3,
-        activation: Type[nn.Module] = nn.GELU,
+        activation: type[nn.Module] = nn.GELU,
         iou_head_depth: int = 3,
         iou_head_hidden_dim: int = 256,
     ) -> None:
-        """
-        Initialize the MaskDecoder module for generating masks and their associated quality scores.
+        """Initialize the MaskDecoder module for generating masks and their associated quality scores.
 
         Args:
             transformer_dim (int): Channel dimension for the transformer module.
@@ -93,9 +91,8 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Predict masks given image and prompt embeddings.
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder.
@@ -129,7 +126,6 @@ class MaskDecoder(nn.Module):
         masks = masks[:, mask_slice, :, :]
         iou_pred = iou_pred[:, mask_slice]
 
-        # Prepare output
         return masks, iou_pred
 
     def predict_masks(
@@ -138,7 +134,7 @@ class MaskDecoder(nn.Module):
         image_pe: torch.Tensor,
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Predict masks and quality scores using image and prompt embeddings via transformer architecture."""
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
@@ -159,7 +155,7 @@ class MaskDecoder(nn.Module):
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
-        hyper_in_list: List[torch.Tensor] = [
+        hyper_in_list: list[torch.Tensor] = [
             self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
         ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
@@ -173,11 +169,10 @@ class MaskDecoder(nn.Module):
 
 
 class SAM2MaskDecoder(nn.Module):
-    """
-    Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
+    """Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
 
-    This class extends the functionality of the MaskDecoder, incorporating additional features such as
-    high-resolution feature processing, dynamic multimask output, and object score prediction.
+    This class extends the functionality of the MaskDecoder, incorporating additional features such as high-resolution
+    feature processing, dynamic multimask output, and object score prediction.
 
     Attributes:
         transformer_dim (int): Channel dimension of the transformer.
@@ -201,10 +196,10 @@ class SAM2MaskDecoder(nn.Module):
         dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
 
     Methods:
-        forward: Predicts masks given image and prompt embeddings.
-        predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
-        _get_stability_scores: Computes mask stability scores based on IoU between thresholds.
-        _dynamic_multimask_via_stability: Dynamically selects the most stable mask output.
+        forward: Predict masks given image and prompt embeddings.
+        predict_masks: Predict instance segmentation masks from image and prompt embeddings.
+        _get_stability_scores: Compute mask stability scores based on IoU between thresholds.
+        _dynamic_multimask_via_stability: Dynamically select the most stable mask output.
 
     Examples:
         >>> image_embeddings = torch.rand(1, 256, 64, 64)
@@ -222,7 +217,7 @@ class SAM2MaskDecoder(nn.Module):
         transformer_dim: int,
         transformer: nn.Module,
         num_multimask_outputs: int = 3,
-        activation: Type[nn.Module] = nn.GELU,
+        activation: type[nn.Module] = nn.GELU,
         iou_head_depth: int = 3,
         iou_head_hidden_dim: int = 256,
         use_high_res_features: bool = False,
@@ -234,11 +229,10 @@ class SAM2MaskDecoder(nn.Module):
         pred_obj_scores_mlp: bool = False,
         use_multimask_token_for_obj_ptr: bool = False,
     ) -> None:
-        """
-        Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
+        """Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
 
-        This decoder extends the functionality of MaskDecoder, incorporating additional features such as
-        high-resolution feature processing, dynamic multimask output, and object score prediction.
+        This decoder extends the functionality of MaskDecoder, incorporating additional features such as high-resolution
+        feature processing, dynamic multimask output, and object score prediction.
 
         Args:
             transformer_dim (int): Channel dimension of the transformer.
@@ -318,10 +312,9 @@ class SAM2MaskDecoder(nn.Module):
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
         repeat_image: bool,
-        high_res_features: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Predict masks given image and prompt embeddings.
+        high_res_features: list[torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
@@ -330,7 +323,7 @@ class SAM2MaskDecoder(nn.Module):
             dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
             multimask_output (bool): Whether to return multiple masks or a single mask.
             repeat_image (bool): Flag to repeat the image embeddings.
-            high_res_features (List[torch.Tensor] | None, optional): Optional high-resolution features.
+            high_res_features (list[torch.Tensor] | None, optional): Optional high-resolution features.
 
         Returns:
             masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
@@ -377,7 +370,6 @@ class SAM2MaskDecoder(nn.Module):
         # are always the single mask token (and we'll let it be the object-memory token).
         sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape
 
-        # Prepare output
         return masks, iou_pred, sam_tokens_out, object_score_logits
 
     def predict_masks(
@@ -387,8 +379,8 @@ class SAM2MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
         repeat_image: bool,
-        high_res_features: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        high_res_features: list[torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Predict instance segmentation masks from image and prompt embeddings using a transformer."""
         # Concatenate output tokens
         s = 0
@@ -404,7 +396,7 @@ class SAM2MaskDecoder(nn.Module):
             s = 1
         else:
             output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
-        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
         tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
 
         # Expand per-image data in batch direction to be per-mask
@@ -414,7 +406,7 @@ class SAM2MaskDecoder(nn.Module):
             assert image_embeddings.shape[0] == tokens.shape[0]
             src = image_embeddings
         src = src + dense_prompt_embeddings
-        assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+        assert image_pe.shape[0] == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
 
@@ -425,7 +417,7 @@ class SAM2MaskDecoder(nn.Module):
 
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
-        if not self.use_high_res_features:
+        if not self.use_high_res_features or high_res_features is None:
             upscaled_embedding = self.output_upscaling(src)
         else:
             dc1, ln1, act1, dc2, act2 = self.output_upscaling
@@ -433,7 +425,7 @@ class SAM2MaskDecoder(nn.Module):
             upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
             upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
 
-        hyper_in_list: List[torch.Tensor] = [
+        hyper_in_list: list[torch.Tensor] = [
             self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
         ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
@@ -460,17 +452,16 @@ class SAM2MaskDecoder(nn.Module):
         return torch.where(area_u > 0, area_i / area_u, 1.0)
 
     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
-        """
-        Dynamically select the most stable mask output based on stability scores and IoU predictions.
+        """Dynamically select the most stable mask output based on stability scores and IoU predictions.
 
-        This method is used when outputting a single mask. If the stability score from the current single-mask
-        output (based on output token 0) falls below a threshold, we instead select from multi-mask outputs
-        (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
-        for both clicking and tracking scenarios.
+        This method is used when outputting a single mask. If the stability score from the current single-mask output
+        (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs (based on output
+        tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask for both clicking and
+        tracking scenarios.
 
         Args:
-            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
-                batch size, N is number of masks (typically 4), and H, W are mask dimensions.
+            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is batch size, N
+                is number of masks (typically 4), and H, W are mask dimensions.
             all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
 
         Returns:
@@ -489,7 +480,7 @@ class SAM2MaskDecoder(nn.Module):
         multimask_logits = all_mask_logits[:, 1:, :, :]
         multimask_iou_scores = all_iou_scores[:, 1:]
         best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
-        batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
+        batch_inds = torch.arange(multimask_iou_scores.shape[0], device=all_iou_scores.device)
         best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
         best_multimask_logits = best_multimask_logits.unsqueeze(1)
         best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]