dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/build_sam3.py
ADDED
@@ -0,0 +1,377 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+import torch.nn as nn
+
+from ultralytics.nn.modules.transformer import MLP
+from ultralytics.utils.patches import torch_load
+
+from .modules.blocks import PositionEmbeddingSine, RoPEAttention
+from .modules.encoders import MemoryEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM3Model
+from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
+from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
+from .sam3.geometry_encoders import SequenceGeometryEncoder
+from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
+from .sam3.model_misc import DotProductScoring, TransformerWrapper
+from .sam3.necks import Sam3DualViTDetNeck
+from .sam3.sam3_image import SAM3SemanticModel
+from .sam3.text_encoder_ve import VETextEncoder
+from .sam3.vitdet import ViT
+from .sam3.vl_combiner import SAM3VLBackbone
+
+
+def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
+    """Create SAM3 visual backbone with ViT and neck."""
+    # Position encoding
+    position_encoding = PositionEmbeddingSine(
+        num_pos_feats=256,
+        normalize=True,
+        scale=None,
+        temperature=10000,
+    )
+
+    # ViT backbone
+    vit_backbone = ViT(
+        img_size=1008,
+        pretrain_img_size=336,
+        patch_size=14,
+        embed_dim=1024,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4.625,
+        norm_layer="LayerNorm",
+        drop_path_rate=0.1,
+        qkv_bias=True,
+        use_abs_pos=True,
+        tile_abs_pos=True,
+        global_att_blocks=(7, 15, 23, 31),
+        rel_pos_blocks=(),
+        use_rope=True,
+        use_interp_rope=True,
+        window_size=24,
+        pretrain_use_cls_token=True,
+        retain_cls_token=False,
+        ln_pre=True,
+        ln_post=False,
+        return_interm_layers=False,
+        bias_patch_embed=False,
+        compile_mode=compile_mode,
+    )
+    return Sam3DualViTDetNeck(
+        position_encoding=position_encoding,
+        d_model=256,
+        scale_factors=[4.0, 2.0, 1.0, 0.5],
+        trunk=vit_backbone,
+        add_sam2_neck=enable_inst_interactivity,
+    )
+
+
+def _create_sam3_transformer() -> TransformerWrapper:
+    """Create SAM3 detector encoder and decoder."""
+    encoder: TransformerEncoderFusion = TransformerEncoderFusion(
+        layer=TransformerEncoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=True,
+            pos_enc_at_cross_attn_keys=False,
+            pos_enc_at_cross_attn_queries=False,
+            pre_norm=True,
+            self_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+                batch_first=True,
+            ),
+            cross_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+                batch_first=True,
+            ),
+        ),
+        num_layers=6,
+        d_model=256,
+        num_feature_levels=1,
+        frozen=False,
+        use_act_checkpoint=True,
+        add_pooled_text_to_img_feat=False,
+        pool_text_with_mask=True,
+    )
+    decoder: TransformerDecoder = TransformerDecoder(
+        layer=TransformerDecoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            cross_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+            ),
+            n_heads=8,
+            use_text_cross_attention=True,
+        ),
+        num_layers=6,
+        num_queries=200,
+        return_intermediate=True,
+        box_refine=True,
+        num_o2m_queries=0,
+        dac=True,
+        boxRPB="log",
+        d_model=256,
+        frozen=False,
+        interaction_layer=None,
+        dac_use_selfatt_ln=True,
+        use_act_checkpoint=True,
+        presence_token=True,
+    )
+
+    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
+
+
+def build_sam3_image_model(checkpoint_path: str, enable_segmentation: bool = True, compile: bool = False):
+    """Build SAM3 image model.
+
+    Args:
+        checkpoint_path: Optional path to model checkpoint
+        enable_segmentation: Whether to enable segmentation head
+        compile: To enable compilation, set to "default"
+
+    Returns:
+        A SAM3 image model
+    """
+    try:
+        import clip
+    except ImportError:
+        from ultralytics.utils.checks import check_requirements
+
+        check_requirements("git+https://github.com/ultralytics/CLIP.git")
+        import clip
+    # Create visual components
+    compile_mode = "default" if compile else None
+    vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)
+
+    # Create text components
+    text_encoder = VETextEncoder(
+        tokenizer=clip.simple_tokenizer.SimpleTokenizer(),
+        d_model=256,
+        width=1024,
+        heads=16,
+        layers=24,
+    )
+
+    # Create visual-language backbone
+    backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)
+
+    # Create transformer components
+    transformer = _create_sam3_transformer()
+
+    # Create dot product scoring
+    dot_prod_scoring = DotProductScoring(
+        d_model=256,
+        d_proj=256,
+        prompt_mlp=MLP(
+            input_dim=256,
+            hidden_dim=2048,
+            output_dim=256,
+            num_layers=2,
+            residual=True,
+            out_norm=nn.LayerNorm(256),
+        ),
+    )
+
+    # Create segmentation head if enabled
+    segmentation_head = (
+        UniversalSegmentationHead(
+            hidden_dim=256,
+            upsampling_stages=3,
+            aux_masks=False,
+            presence_head=False,
+            dot_product_scorer=None,
+            act_ckpt=True,
+            cross_attend_prompt=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0,
+                embed_dim=256,
+            ),
+            pixel_decoder=PixelDecoder(
+                num_upsampling_stages=3,
+                interpolation_mode="nearest",
+                hidden_dim=256,
+                compile_mode=compile_mode,
+            ),
+        )
+        if enable_segmentation
+        else None
+    )
+
+    # Create geometry encoder
+    input_geometry_encoder = SequenceGeometryEncoder(
+        pos_enc=PositionEmbeddingSine(
+            num_pos_feats=256,
+            normalize=True,
+            scale=None,
+            temperature=10000,
+        ),
+        encode_boxes_as_points=False,
+        boxes_direct_project=True,
+        boxes_pool=True,
+        boxes_pos_enc=True,
+        d_model=256,
+        num_layers=3,
+        layer=TransformerEncoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=False,
+            pre_norm=True,
+            pos_enc_at_cross_attn_queries=False,
+            pos_enc_at_cross_attn_keys=True,
+        ),
+        use_act_ckpt=True,
+        add_cls=True,
+        add_post_encode_proj=True,
+    )
+
+    # Create the SAM3SemanticModel model
+    model = SAM3SemanticModel(
+        backbone=backbone,
+        transformer=transformer,
+        input_geometry_encoder=input_geometry_encoder,
+        segmentation_head=segmentation_head,
+        num_feature_levels=1,
+        o2m_mask_predict=True,
+        dot_prod_scoring=dot_prod_scoring,
+        use_instance_query=False,
+        multimask_output=True,
+    )
+
+    # Load checkpoint
+    model = _load_checkpoint(model, checkpoint_path)
+    model.eval()
+    return model
+
+
+def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
+    """Build the SAM3 Tracker module for video tracking.
+
+    Returns:
+        Sam3TrackerPredictor: Wrapped SAM3 Tracker module
+    """
+    # Create model components
+    memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
+    memory_attention = MemoryAttention(
+        batch_first=True,
+        d_model=256,
+        pos_enc_at_input=True,
+        layer=MemoryAttentionLayer(
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=False,
+            pos_enc_at_cross_attn_keys=True,
+            pos_enc_at_cross_attn_queries=False,
+            self_attn=RoPEAttention(
+                embedding_dim=256,
+                num_heads=1,
+                downsample_rate=1,
+                rope_theta=10000.0,
+                feat_sizes=[72, 72],
+            ),
+            d_model=256,
+            cross_attn=RoPEAttention(
+                embedding_dim=256,
+                num_heads=1,
+                downsample_rate=1,
+                kv_in_dim=64,
+                rope_theta=10000.0,
+                feat_sizes=[72, 72],
+                rope_k_repeat=True,
+            ),
+        ),
+        num_layers=4,
+    )
+
+    backbone = (
+        SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)
+        if with_backbone
+        else None
+    )
+    model = SAM3Model(
+        image_size=1008,
+        image_encoder=backbone,
+        memory_attention=memory_attention,
+        memory_encoder=memory_encoder,
+        backbone_stride=14,
+        num_maskmem=7,
+        sigmoid_scale_for_mem_enc=20.0,
+        sigmoid_bias_for_mem_enc=-10.0,
+        use_mask_input_as_output_without_sam=True,
+        directly_add_no_mem_embed=True,
+        use_high_res_features_in_sam=True,
+        multimask_output_in_sam=True,
+        iou_prediction_use_sigmoid=True,
+        use_obj_ptrs_in_encoder=True,
+        add_tpos_enc_to_obj_ptrs=True,
+        only_obj_ptrs_in_the_past_for_eval=True,
+        pred_obj_scores=True,
+        pred_obj_scores_mlp=True,
+        fixed_no_obj_ptr=True,
+        multimask_output_for_tracking=True,
+        use_multimask_token_for_obj_ptr=True,
+        multimask_min_pt_num=0,
+        multimask_max_pt_num=1,
+        use_mlp_for_obj_ptr_proj=True,
+        compile_image_encoder=False,
+        no_obj_embed_spatial=True,
+        proj_tpos_enc_in_obj_ptrs=True,
+        use_signed_tpos_enc_to_obj_ptrs=True,
+        sam_mask_decoder_extra_args=dict(
+            dynamic_multimask_via_stability=True,
+            dynamic_multimask_stability_delta=0.05,
+            dynamic_multimask_stability_thresh=0.98,
+        ),
+    )
+
+    # Load checkpoint if provided
+    model = _load_checkpoint(model, checkpoint_path, interactive=True)
+
+    # Setup device and mode
+    model.eval()
+    return model
+
+
+def _load_checkpoint(model, checkpoint, interactive=False):
+    """Load SAM3 model checkpoint from file."""
+    with open(checkpoint, "rb") as f:
+        ckpt = torch_load(f)
+    if "model" in ckpt and isinstance(ckpt["model"], dict):
+        ckpt = ckpt["model"]
+    sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
+    if interactive:
+        sam3_image_ckpt.update(
+            {
+                k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v
+                for k, v in sam3_image_ckpt.items()
+                if "backbone.vision_backbone" in k
+            }
+        )
+        sam3_image_ckpt.update(
+            {
+                k.replace("tracker.transformer.encoder", "memory_attention"): v
+                for k, v in ckpt.items()
+                if "tracker.transformer" in k
+            }
+        )
+        sam3_image_ckpt.update(
+            {
+                k.replace("tracker.maskmem_backbone", "memory_encoder"): v
+                for k, v in ckpt.items()
+                if "tracker.maskmem_backbone" in k
+            }
+        )
+        sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." in k})
+    model.load_state_dict(sam3_image_ckpt, strict=False)
+    return model
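
For orientation, a minimal usage sketch of the two public builders added above, based only on the signatures visible in this diff; the checkpoint filename `sam3.pt` is a placeholder, not a published asset:

```python
# Hedged sketch based on the build_sam3.py signatures shown above.
# "sam3.pt" is a hypothetical local checkpoint path, not a released file.
from ultralytics.models.sam.build_sam3 import build_interactive_sam3, build_sam3_image_model

# Promptable image model (segmentation head enabled by default).
image_model = build_sam3_image_model("sam3.pt", enable_segmentation=True)

# Interactive/video variant: a SAM3Model wired with a memory encoder and memory attention.
tracker = build_interactive_sam3("sam3.pt")
```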
ultralytics/models/sam/model.py
CHANGED
@@ -21,16 +21,15 @@ from pathlib import Path
 from ultralytics.engine.model import Model
 from ultralytics.utils.torch_utils import model_info
 
-from .predict import Predictor, SAM2Predictor
+from .predict import Predictor, SAM2Predictor, SAM3Predictor
 
 
 class SAM(Model):
-    """
-    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+    """SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
 
-    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
-
-
+    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for promptable
+    segmentation with versatility in image analysis. It supports various prompts such as bounding boxes, points, or
+    labels, and features zero-shot performance capabilities.
 
     Attributes:
         model (torch.nn.Module): The loaded SAM model.
@@ -45,31 +44,26 @@ class SAM(Model):
         >>> sam = SAM("sam_b.pt")
         >>> results = sam.predict("image.jpg", points=[[500, 375]])
         >>> for r in results:
-
+        ...     print(f"Detected {len(r.masks)} masks")
     """
 
     def __init__(self, model: str = "sam_b.pt") -> None:
-        """
-        Initialize the SAM (Segment Anything Model) instance.
+        """Initialize the SAM (Segment Anything Model) instance.
 
         Args:
             model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
 
         Raises:
             NotImplementedError: If the model file extension is not .pt or .pth.
-
-        Examples:
-            >>> sam = SAM("sam_b.pt")
-            >>> print(sam.is_sam2)
         """
         if model and Path(model).suffix not in {".pt", ".pth"}:
             raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
         self.is_sam2 = "sam2" in Path(model).stem
+        self.is_sam3 = "sam3" in Path(model).stem
         super().__init__(model=model, task="segment")
 
     def _load(self, weights: str, task=None):
-        """
-        Load the specified weights into the SAM model.
+        """Load the specified weights into the SAM model.
 
         Args:
             weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
@@ -79,17 +73,21 @@ class SAM(Model):
             >>> sam = SAM("sam_b.pt")
             >>> sam._load("path/to/custom_weights.pt")
         """
-        from .build import build_sam  # slow import
+        if self.is_sam3:
+            from .build_sam3 import build_interactive_sam3
 
-        self.model = build_sam(weights)
+            self.model = build_interactive_sam3(weights)
+        else:
+            from .build import build_sam  # slow import
+
+            self.model = build_sam(weights)
 
     def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
-        """
-        Perform segmentation prediction on the given image or video source.
+        """Perform segmentation prediction on the given image or video source.
 
         Args:
-            source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or
-
+            source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or a
+                np.ndarray object.
             stream (bool): If True, enables real-time streaming.
             bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
             points (list[list[float]] | None): List of points for prompted segmentation.
@@ -111,15 +109,14 @@ class SAM(Model):
         return super().predict(source, stream, prompts=prompts, **kwargs)
 
     def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
-        """
-        Perform segmentation prediction on the given image or video source.
+        """Perform segmentation prediction on the given image or video source.
 
-        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
-
+        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model for
+        segmentation tasks.
 
         Args:
-            source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image
-
+            source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image object, or a
+                np.ndarray object.
             stream (bool): If True, enables real-time streaming.
             bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
             points (list[list[float]] | None): List of points for prompted segmentation.
@@ -137,8 +134,7 @@ class SAM(Model):
         return self.predict(source, stream, bboxes, points, labels, **kwargs)
 
     def info(self, detailed: bool = False, verbose: bool = True):
-        """
-        Log information about the SAM model.
+        """Log information about the SAM model.
 
         Args:
             detailed (bool): If True, displays detailed information about the model layers and operations.
@@ -156,8 +152,7 @@ class SAM(Model):
 
     @property
     def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
-        """
-        Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
+        """Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
 
         Returns:
             (dict[str, dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
@@ -169,4 +164,6 @@ class SAM(Model):
             >>> print(task_map)
             {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
         """
-        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
+        return {
+            "segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor}
+        }
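
For context, a minimal sketch of how the updated `SAM` wrapper appears to route SAM3 weights, inferred from the `is_sam3` flag, `_load`, and `task_map` changes above; the checkpoint name `sam3.pt` is a placeholder:

```python
# Hedged sketch inferred from the SAM class changes above; "sam3.pt" is a placeholder checkpoint.
from ultralytics import SAM

sam = SAM("sam3.pt")  # "sam3" in the filename stem sets is_sam3, so _load() calls build_interactive_sam3()
results = sam.predict("image.jpg", points=[[500, 375]])  # task_map dispatches to SAM3Predictor
for r in results:
    print(f"Detected {len(r.masks)} masks")
```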